mirror of
https://github.com/langgenius/dify.git
synced 2026-04-06 06:21:27 +08:00
Compare commits
5 Commits
deploy/cle
...
refactor/b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed9623647e | ||
|
|
aa5a22991b | ||
|
|
4928917878 | ||
|
|
b00afff61e | ||
|
|
691248f477 |
@@ -204,16 +204,6 @@ When assigned to test a directory/path, test **ALL content** within that path:
|
||||
|
||||
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
|
||||
|
||||
### `nuqs` Query State Testing (Required for URL State Hooks)
|
||||
|
||||
When a component or hook uses `useQueryState` / `useQueryStates`:
|
||||
|
||||
- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`)
|
||||
- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`)
|
||||
- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values)
|
||||
- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable)
|
||||
- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test
|
||||
|
||||
## Core Principles
|
||||
|
||||
### 1. AAA Pattern (Arrange-Act-Assert)
|
||||
|
||||
@@ -80,9 +80,6 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
|
||||
- [ ] Router mocks match actual Next.js API
|
||||
- [ ] Mocks reflect actual component conditional behavior
|
||||
- [ ] Only mock: API services, complex context providers, third-party libs
|
||||
- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`)
|
||||
- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`)
|
||||
- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values)
|
||||
|
||||
### Queries
|
||||
|
||||
|
||||
@@ -125,31 +125,6 @@ describe('Component', () => {
|
||||
})
|
||||
```
|
||||
|
||||
### 2.1 `nuqs` Query State (Preferred: Testing Adapter)
|
||||
|
||||
For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly.
|
||||
|
||||
```typescript
|
||||
import { renderHookWithNuqs } from '@/test/nuqs-testing'
|
||||
|
||||
it('should sync query to URL with push history', async () => {
|
||||
const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), {
|
||||
searchParams: '?page=1',
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ page: 2 })
|
||||
})
|
||||
|
||||
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
|
||||
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
|
||||
expect(update.options.history).toBe('push')
|
||||
expect(update.searchParams.get('page')).toBe('2')
|
||||
})
|
||||
```
|
||||
|
||||
Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope.
|
||||
|
||||
### 3. Portal Components (with Shared State)
|
||||
|
||||
```typescript
|
||||
|
||||
@@ -1,100 +1,43 @@
|
||||
---
|
||||
name: orpc-contract-first
|
||||
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
|
||||
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
|
||||
---
|
||||
|
||||
# oRPC Contract-First Development
|
||||
|
||||
## Intent
|
||||
## Project Structure
|
||||
|
||||
- Keep contract as single source of truth in `web/contract/*`.
|
||||
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
```
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
├── base.ts # Base contract (inputStructure: 'detailed')
|
||||
├── router.ts # Router composition & type exports
|
||||
├── marketplace.ts # Marketplace contracts
|
||||
└── console/ # Console contracts by domain
|
||||
├── system.ts
|
||||
└── billing.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
## Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
|
||||
- Use `base.route({...}).output(type<...>())` as baseline.
|
||||
- Add `.input(type<...>())` only when request has `params/query/body`.
|
||||
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
|
||||
2. Register contract in `web/contract/router.ts`
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utils.
|
||||
1. **Create contract** in `web/contract/console/{domain}.ts`
|
||||
- Import `base` from `../base` and `type` from `@orpc/contract`
|
||||
- Define route with `path`, `method`, `input`, `output`
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
2. **Register in router** at `web/contract/router.ts`
|
||||
- Import directly from domain file (no barrel files)
|
||||
- Nest by API prefix: `billing: { invoices, bindPartnerStack }`
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
3. **Create hooks** in `web/service/use-{domain}.ts`
|
||||
- Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
|
||||
- Use `consoleClient.{group}.{contract}()` for API calls
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default: call site directly uses `*.queryOptions(...)`.
|
||||
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration:
|
||||
- Combine multiple queries/mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
|
||||
|
||||
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
|
||||
|
||||
- `.key(...)`:
|
||||
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`:
|
||||
- Use for a specific query's full key (exact query identity / direct cache addressing).
|
||||
- `.mutationKey(...)`:
|
||||
- Use for a specific mutation's full key.
|
||||
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
## Key Rules
|
||||
|
||||
- **Input structure**: Always use `{ params, query?, body? }` format
|
||||
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
|
||||
- **Path params**: Use `{paramName}` in path, match in `params` object
|
||||
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
|
||||
- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
|
||||
- **No barrel files**: Import directly from specific files
|
||||
- **Types**: Import from `@/types/`, use `type<T>()` helper
|
||||
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
|
||||
|
||||
## Type Export
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
32
.github/dependabot.yml
vendored
32
.github/dependabot.yml
vendored
@@ -1,43 +1,25 @@
|
||||
version: 2
|
||||
|
||||
multi-ecosystem-groups:
|
||||
python:
|
||||
schedule:
|
||||
interval: "weekly" # or whatever schedule you want
|
||||
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
patterns: ["*"]
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
python-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
patterns: ["*"]
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
uv-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/web"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
storybook:
|
||||
patterns:
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
npm-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
exclude-patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -89,7 +89,7 @@ jobs:
|
||||
uses: actions/setup-node@v6
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
2
.github/workflows/tool-test-sdks.yaml
vendored
2
.github/workflows/tool-test-sdks.yaml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: ''
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
@@ -457,7 +457,7 @@ jobs:
|
||||
uses: actions/setup-node@v6
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -222,7 +222,6 @@ mise.toml
|
||||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
/.claude/worktrees/
|
||||
api/.env.backup
|
||||
/clickzetta
|
||||
|
||||
|
||||
2
.vscode/launch.json.template
vendored
2
.vscode/launch.json.template
vendored
@@ -37,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
@@ -29,7 +29,7 @@ The codebase is split into:
|
||||
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
5
Makefile
5
Makefile
@@ -68,9 +68,8 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@echo "📝 Running type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
@@ -132,7 +131,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make type-check - Run type checks (basedpyright, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@@ -42,8 +42,6 @@ REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
# redis configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
# Optional: limit total connections in connection pool (unset for default)
|
||||
# REDIS_MAX_CONNECTIONS=200
|
||||
REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
|
||||
@@ -28,8 +28,17 @@ ignore_imports =
|
||||
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events
|
||||
dify_graph.nodes.loop.loop_node -> dify_graph.graph_events
|
||||
|
||||
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
|
||||
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine
|
||||
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph
|
||||
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine.command_channels
|
||||
dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine
|
||||
dify_graph.nodes.loop.loop_node -> dify_graph.graph
|
||||
dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine.command_channels
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
|
||||
|
||||
@@ -49,6 +58,8 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
[importlinter:contract:workflow-external-imports]
|
||||
name = Workflow External Imports
|
||||
@@ -92,9 +103,13 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.agent.agent_node -> core.model_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.provider_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
dify_graph.nodes.llm.llm_utils -> core.model_manager
|
||||
dify_graph.nodes.llm.protocols -> core.model_manager
|
||||
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -138,7 +153,10 @@ ignore_imports =
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
dify_graph.nodes.human_input.human_input_node -> core.repositories.human_input_repository
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
dify_graph.nodes.tool.tool_node -> services
|
||||
|
||||
@@ -62,22 +62,6 @@ This is the default standard for backend code in this repo. Follow it for new co
|
||||
|
||||
- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
|
||||
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
|
||||
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
|
||||
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
|
||||
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
|
||||
class UserProfile(TypedDict):
|
||||
user_id: str
|
||||
email: str
|
||||
created_at: datetime
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
|
||||
206
api/commands.py
206
api/commands.py
@@ -30,7 +30,6 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
@@ -937,12 +936,6 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
|
||||
is_flag=True,
|
||||
help="Preview cleanup results without deleting any workflow run data.",
|
||||
)
|
||||
@click.option(
|
||||
"--task-label",
|
||||
default="daily",
|
||||
show_default=True,
|
||||
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
|
||||
)
|
||||
def clean_workflow_runs(
|
||||
before_days: int,
|
||||
batch_size: int,
|
||||
@@ -951,13 +944,10 @@ def clean_workflow_runs(
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
dry_run: bool,
|
||||
task_label: str,
|
||||
):
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
@@ -977,17 +967,13 @@ def clean_workflow_runs(
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
try:
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
).run()
|
||||
finally:
|
||||
flush_telemetry()
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
@@ -2612,29 +2598,15 @@ def migrate_oss(
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=False,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=False,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
|
||||
)
|
||||
@click.option(
|
||||
"--before-days",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
|
||||
)
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
|
||||
@click.option(
|
||||
"--graceful-period",
|
||||
@@ -2643,99 +2615,33 @@ def migrate_oss(
|
||||
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
|
||||
)
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
|
||||
@click.option(
|
||||
"--task-label",
|
||||
default="daily",
|
||||
show_default=True,
|
||||
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
|
||||
)
|
||||
def clean_expired_messages(
|
||||
batch_size: int,
|
||||
graceful_period: int,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
from_days_ago: int | None,
|
||||
before_days: int | None,
|
||||
start_from: datetime.datetime,
|
||||
end_before: datetime.datetime,
|
||||
dry_run: bool,
|
||||
task_label: str,
|
||||
):
|
||||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
abs_mode = start_from is not None and end_before is not None
|
||||
rel_mode = before_days is not None
|
||||
|
||||
if abs_mode and rel_mode:
|
||||
raise click.UsageError(
|
||||
"Options are mutually exclusive: use either (--start-from,--end-before) "
|
||||
"or (--from-days-ago,--before-days)."
|
||||
)
|
||||
|
||||
if from_days_ago is not None and before_days is None:
|
||||
raise click.UsageError("--from-days-ago must be used together with --before-days.")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
|
||||
|
||||
if not abs_mode and not rel_mode:
|
||||
raise click.UsageError(
|
||||
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
|
||||
)
|
||||
|
||||
if rel_mode:
|
||||
assert before_days is not None
|
||||
if before_days < 0:
|
||||
raise click.UsageError("--before-days must be >= 0.")
|
||||
if from_days_ago is not None:
|
||||
if from_days_ago < 0:
|
||||
raise click.UsageError("--from-days-ago must be >= 0.")
|
||||
if from_days_ago <= before_days:
|
||||
raise click.UsageError("--from-days-ago must be greater than --before-days.")
|
||||
|
||||
# Create policy based on billing configuration
|
||||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
# Create and run the cleanup service
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
service = MessagesCleanService.from_days(
|
||||
policy=policy,
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
assert from_days_ago is not None
|
||||
now = naive_utc_now()
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=now - datetime.timedelta(days=from_days_ago),
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
@@ -2760,81 +2666,5 @@ def clean_expired_messages(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
|
||||
@click.option("--app-id", required=True, help="Application ID to export messages for.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--filename",
|
||||
required=True,
|
||||
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
|
||||
)
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
def export_app_messages(
|
||||
app_id: str,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
use_cloud_storage: bool,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if start_from and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be before --end-before.")
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
try:
|
||||
validated_filename = AppMessageExportService.validate_export_filename(filename)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(str(e), param_hint="--filename") from e
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=validated_filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"export_app_messages: completed in {elapsed:.2f}s\n"
|
||||
f" - Batches: {stats.batches}\n"
|
||||
f" - Total messages: {stats.total_messages}\n"
|
||||
f" - Messages with feedback: {stats.messages_with_feedback}\n"
|
||||
f" - Total feedbacks: {stats.total_feedbacks}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
logger.exception("export_app_messages failed")
|
||||
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
|
||||
raise
|
||||
|
||||
5
api/configs/middleware/cache/redis_config.py
vendored
5
api/configs/middleware/cache/redis_config.py
vendored
@@ -111,8 +111,3 @@ class RedisConfig(BaseSettings):
|
||||
description="Enable client side cache in redis",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_MAX_CONNECTIONS: PositiveInt | None = Field(
|
||||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Literal, Protocol
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -23,56 +23,41 @@ class RedisConfigDefaultsMixin:
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
Supported transports:
|
||||
- pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once)
|
||||
- sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling)
|
||||
- streams: Redis Streams (at-least-once, supports late subscribers)
|
||||
Configuration settings for Redis pub/sub streaming.
|
||||
"""
|
||||
|
||||
PUBSUB_REDIS_URL: str | None = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"),
|
||||
alias="PUBSUB_REDIS_URL",
|
||||
description=(
|
||||
"Redis connection URL for streaming events between API and celery worker; "
|
||||
"defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL."
|
||||
"Redis connection URL for pub/sub streaming events between API "
|
||||
"and celery worker, defaults to url constructed from "
|
||||
"`REDIS_*` configurations"
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
|
||||
description=(
|
||||
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
|
||||
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
|
||||
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
|
||||
"recommended to enable this for large deployments."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"),
|
||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
|
||||
description=(
|
||||
"Event transport type. Options are:\n\n"
|
||||
" - pubsub: normal Pub/Sub (at-most-once)\n"
|
||||
" - sharded: sharded Pub/Sub (at-most-once)\n"
|
||||
" - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n"
|
||||
"Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n"
|
||||
"Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n"
|
||||
"the risk of data loss from Redis auto-eviction under memory pressure.\n"
|
||||
"Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE."
|
||||
"Pub/sub channel type for streaming events. "
|
||||
"Valid options are:\n"
|
||||
"\n"
|
||||
" - pubsub: for normal Pub/Sub\n"
|
||||
" - sharded: for sharded Pub/Sub\n"
|
||||
"\n"
|
||||
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
|
||||
"for large deployments."
|
||||
),
|
||||
default="pubsub",
|
||||
)
|
||||
|
||||
PUBSUB_STREAMS_RETENTION_SECONDS: int = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"),
|
||||
description=(
|
||||
"When using 'streams', expire each stream key this many seconds after the last event is published. "
|
||||
"Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS."
|
||||
),
|
||||
default=600,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
@@ -25,14 +23,14 @@ class AppParameterApi(InstalledAppResource):
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -185,4 +185,4 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
def delete(self, app_model: App, annotation_id: str):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
|
||||
return "", 204
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
@@ -35,14 +33,14 @@ class AppParameterApi(Resource):
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
ConversationDelete,
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
@@ -162,7 +163,7 @@ class ConversationDetailApi(Resource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return "", 204
|
||||
return ConversationDelete(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
||||
@@ -132,8 +132,6 @@ class WorkflowRunDetailApi(Resource):
|
||||
app_id=app_model.id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found.")
|
||||
return workflow_run
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -58,14 +57,14 @@ class AppParameterApi(WebApiResource):
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
|
||||
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
|
||||
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
|
||||
if not sensitive_word_avoidance_dict:
|
||||
return None
|
||||
@@ -15,7 +12,7 @@ class SensitiveWordAvoidanceConfigManager:
|
||||
if sensitive_word_avoidance_dict.get("enabled"):
|
||||
return SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config", {}),
|
||||
config=sensitive_word_avoidance_dict.get("config"),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
|
||||
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
|
||||
class AgentConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
|
||||
def convert(cls, config: dict) -> AgentEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -31,17 +28,17 @@ class AgentConfigManager:
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
tool_dict = cast(dict[str, Any], tool)
|
||||
if len(tool_dict) >= 4:
|
||||
if "enabled" not in tool_dict or not tool_dict["enabled"]:
|
||||
keys = tool.keys()
|
||||
if len(keys) >= 4:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties = {
|
||||
"provider_type": tool_dict["provider_type"],
|
||||
"provider_id": tool_dict["provider_id"],
|
||||
"tool_name": tool_dict["tool_name"],
|
||||
"tool_parameters": tool_dict.get("tool_parameters", {}),
|
||||
"credential_id": tool_dict.get("credential_id", None),
|
||||
"provider_type": tool["provider_type"],
|
||||
"provider_id": tool["provider_id"],
|
||||
"tool_name": tool["tool_name"],
|
||||
"tool_parameters": tool.get("tool_parameters", {}),
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
|
||||
@@ -50,8 +47,7 @@ class AgentConfigManager:
|
||||
"react_router",
|
||||
"router",
|
||||
}:
|
||||
agent_prompt_raw = agent_dict.get("prompt", None)
|
||||
agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
|
||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||
# check model mode
|
||||
model_mode = config.get("model", {}).get("mode", "completion")
|
||||
if model_mode == "completion":
|
||||
@@ -79,7 +75,7 @@ class AgentConfigManager:
|
||||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
|
||||
max_iteration=agent_dict.get("max_iteration", 10),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -8,13 +8,13 @@ from core.app.app_config.entities import (
|
||||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
class DatasetConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
|
||||
def convert(cls, config: dict) -> DatasetEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -25,15 +25,11 @@ class DatasetConfigManager:
|
||||
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
|
||||
|
||||
for dataset in datasets.get("datasets", []):
|
||||
if not isinstance(dataset, dict):
|
||||
continue
|
||||
keys = list(dataset.keys())
|
||||
if len(keys) == 0 or keys[0] != "dataset":
|
||||
continue
|
||||
|
||||
dataset = dataset["dataset"]
|
||||
if not isinstance(dataset, dict):
|
||||
continue
|
||||
|
||||
if "enabled" not in dataset or not dataset["enabled"]:
|
||||
continue
|
||||
@@ -51,14 +47,15 @@ class DatasetConfigManager:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
|
||||
for tool in agent_dict.get("tools", []):
|
||||
if len(tool) == 1:
|
||||
keys = tool.keys()
|
||||
if len(keys) == 1:
|
||||
# old standard
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != "dataset":
|
||||
continue
|
||||
|
||||
tool_item = cast(dict[str, Any], tool)[key]
|
||||
tool_item = tool[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
@@ -5,13 +5,12 @@ from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from models.model import AppModelConfigDict
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
|
||||
def convert(cls, config: dict) -> ModelConfigEntity:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -23,7 +22,7 @@ class ModelConfigManager:
|
||||
if not model_config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
completion_params = model_config.get("completion_params") or {}
|
||||
completion_params = model_config.get("completion_params")
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
@@ -8,12 +6,12 @@ from core.app.app_config.entities import (
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class PromptTemplateConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
|
||||
def convert(cls, config: dict) -> PromptTemplateEntity:
|
||||
if not config.get("prompt_type"):
|
||||
raise ValueError("prompt_type is required")
|
||||
|
||||
@@ -42,15 +40,14 @@ class PromptTemplateConfigManager:
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = config.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params: dict[str, Any] = {
|
||||
completion_prompt_template_params = {
|
||||
"prompt": completion_prompt_config["prompt"]["text"],
|
||||
}
|
||||
|
||||
conv_role = completion_prompt_config.get("conversation_histories_role")
|
||||
if conv_role:
|
||||
if "conversation_histories_role" in completion_prompt_config:
|
||||
completion_prompt_template_params["role_prefix"] = {
|
||||
"user": conv_role["user_prefix"],
|
||||
"assistant": conv_role["assistant_prefix"],
|
||||
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
|
||||
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
@@ -20,7 +18,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
|
||||
class BasicVariablesConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -53,9 +51,7 @@ class BasicVariablesConfigManager:
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=variable["variable"],
|
||||
type=variable.get("type", ""),
|
||||
config=variable.get("config", {}),
|
||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||
)
|
||||
)
|
||||
elif variable_type in {
|
||||
@@ -68,10 +64,10 @@ class BasicVariablesConfigManager:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=cast(VariableEntityType, variable_type),
|
||||
variable=variable["variable"],
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable["label"],
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
|
||||
@@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
app_model_config_dict: dict[str, Any]
|
||||
app_model_config_dict: dict
|
||||
model: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
dataset: DatasetEntity | None = None
|
||||
|
||||
@@ -516,10 +516,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
|
||||
yield from self._handle_advanced_chat_message_end_event(
|
||||
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_partial_success_event(
|
||||
self,
|
||||
@@ -540,9 +538,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
yield from self._handle_advanced_chat_message_end_event(
|
||||
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
@@ -740,6 +735,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self._workflow_tenant_id,
|
||||
)
|
||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
||||
@@ -859,14 +855,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueWorkflowSucceededEvent():
|
||||
yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueWorkflowPartialSuccessEvent():
|
||||
yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
@@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
|
||||
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||
|
||||
@@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Conversation | None = None,
|
||||
override_config_dict: AppModelConfigDict | None = None,
|
||||
override_config_dict: dict | None = None,
|
||||
) -> AgentChatAppConfig:
|
||||
"""
|
||||
Convert app model config to agent chat app config
|
||||
@@ -61,9 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
if not override_config_dict:
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
config_dict = override_config_dict
|
||||
config_dict = override_config_dict or {}
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AgentChatAppConfig(
|
||||
@@ -72,7 +70,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
@@ -88,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
|
||||
"""
|
||||
Validate for agent chat app model config
|
||||
|
||||
@@ -159,7 +157,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
return filtered_config
|
||||
|
||||
@classmethod
|
||||
def validate_agent_mode_and_set_defaults(
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
@@ -15,7 +13,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
||||
SuggestedQuestionsAfterAnswerConfigManager,
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
|
||||
|
||||
class ChatAppConfig(EasyUIBasedAppConfig):
|
||||
@@ -33,7 +31,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Conversation | None = None,
|
||||
override_config_dict: AppModelConfigDict | None = None,
|
||||
override_config_dict: dict | None = None,
|
||||
) -> ChatAppConfig:
|
||||
"""
|
||||
Convert app model config to chat app config
|
||||
@@ -66,7 +64,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
@@ -81,7 +79,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
"""
|
||||
Validate for chat app model config
|
||||
|
||||
@@ -147,4 +145,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
return filtered_config
|
||||
|
||||
@@ -173,10 +173,8 @@ class ChatAppRunner(AppRunner):
|
||||
memory=memory,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
@@ -10,7 +8,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
|
||||
|
||||
class CompletionAppConfig(EasyUIBasedAppConfig):
|
||||
@@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
|
||||
class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
|
||||
) -> CompletionAppConfig:
|
||||
"""
|
||||
Convert app model config to completion app config
|
||||
@@ -42,9 +40,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
if not override_config_dict:
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
config_dict = override_config_dict
|
||||
config_dict = override_config_dict or {}
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = CompletionAppConfig(
|
||||
@@ -53,7 +49,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
app_model_config_dict=config_dict,
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
@@ -68,7 +64,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
"""
|
||||
Validate for completion app model config
|
||||
|
||||
@@ -120,4 +116,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
return filtered_config
|
||||
|
||||
@@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
raise ValueError("Message app_model_config is None")
|
||||
override_model_config_dict = app_model_config.to_dict()
|
||||
model_dict = override_model_config_dict["model"]
|
||||
completion_params = model_dict.get("completion_params", {})
|
||||
completion_params = model_dict.get("completion_params")
|
||||
completion_params["temperature"] = 0.9
|
||||
model_dict["completion_params"] = completion_params
|
||||
override_model_config_dict["model"] = model_dict
|
||||
|
||||
@@ -132,10 +132,8 @@ class CompletionAppRunner(AppRunner):
|
||||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
@@ -8,14 +8,12 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
UserFrom,
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.graph_init_params import GraphInitParams
|
||||
from dify_graph.enums import WorkflowType
|
||||
from dify_graph.enums import UserFrom, WorkflowType
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
@@ -258,15 +256,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
# init graph
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@@ -33,6 +33,7 @@ from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import UserFrom
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import (
|
||||
@@ -118,15 +119,13 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
@@ -268,15 +267,13 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
@@ -7,7 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import InvokeFrom
|
||||
from dify_graph.file import File, FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
@@ -15,69 +14,6 @@ if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
class UserFrom(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
class DifyRunContext(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
|
||||
def build_dify_run_context(
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
extra_context: Mapping[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build graph run_context with the reserved Dify runtime payload.
|
||||
|
||||
`extra_context` can carry user-defined context keys. The reserved `_dify`
|
||||
payload is always overwritten by this function to keep one canonical source.
|
||||
"""
|
||||
run_context = dict(extra_context) if extra_context else {}
|
||||
run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
return run_context
|
||||
|
||||
|
||||
class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
"""
|
||||
Model Config With Credentials Entity.
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Any, Union, cast
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -44,13 +44,14 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.app.task_pipeline.message_file_utils import prepare_file_dict
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
@@ -218,14 +219,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
|
||||
if (
|
||||
text_to_speech_dict
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
|
||||
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
|
||||
)
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
@@ -459,40 +460,91 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
|
||||
# Fetch files associated with this message
|
||||
files = None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
|
||||
if message_files:
|
||||
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
|
||||
upload_file_ids = list(
|
||||
dict.fromkeys(
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
)
|
||||
)
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
files_list = []
|
||||
for message_file in message_files:
|
||||
file_dict = prepare_file_dict(message_file, upload_files_map)
|
||||
files_list.append(file_dict)
|
||||
|
||||
files = files_list or None
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=metadata_dict,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def _record_files(self):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
if not message_files:
|
||||
return None
|
||||
|
||||
files_list = []
|
||||
upload_file_ids = [
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
]
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
for message_file in message_files:
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
# Fallback: generate URL even if upload_file not found
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
# For tool files, use URL directly if it's HTTP, otherwise sign it
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
else:
|
||||
# Extract tool file id and extension from URL
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0] # Remove query params first
|
||||
# Use rsplit to correctly handle filenames with multiple dots
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
file_dict = {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
files_list.append(file_dict)
|
||||
return files_list or None
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from threading import Thread, Timer
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -95,9 +96,9 @@ class MessageCycleManager:
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
thread = Timer(
|
||||
1,
|
||||
self._generate_conversation_name_worker,
|
||||
time.sleep(1)
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation_id,
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
:param message_file: MessageFile instance
|
||||
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
|
||||
:return: Dictionary containing file information
|
||||
"""
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method.value
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
return {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
@@ -75,9 +75,8 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
return
|
||||
|
||||
try:
|
||||
dify_ctx = node.require_dify_context()
|
||||
deduct_llm_quota(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=node.tenant_id,
|
||||
model_instance=model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Final
|
||||
from typing import Final, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
@@ -201,7 +201,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||
raise ValueError("UUID cannot be None")
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
return cast(int, uuid_obj.int)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||
|
||||
|
||||
@@ -120,8 +120,7 @@ class TencentTraceClient:
|
||||
|
||||
# Metrics exporter and instruments
|
||||
try:
|
||||
from opentelemetry.sdk.metrics import Histogram as SdkHistogram
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics import Histogram, MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
|
||||
|
||||
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
|
||||
@@ -129,7 +128,7 @@ class TencentTraceClient:
|
||||
use_http_json = protocol in {"http/json", "http-json"}
|
||||
|
||||
# Tencent APM works best with delta aggregation temporality
|
||||
preferred_temporality: dict[type, AggregationTemporality] = {SdkHistogram: AggregationTemporality.DELTA}
|
||||
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
|
||||
|
||||
def _create_metric_exporter(exporter_cls, **kwargs):
|
||||
"""Create metric exporter with preferred_temporality support"""
|
||||
|
||||
@@ -6,6 +6,7 @@ import hashlib
|
||||
import random
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
@@ -22,7 +23,7 @@ class TencentTraceUtils:
|
||||
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
return uuid_obj.int
|
||||
return cast(int, uuid_obj.int)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||
@@ -51,9 +52,9 @@ class TencentTraceUtils:
|
||||
@staticmethod
|
||||
def create_link(trace_id_str: str) -> Link:
|
||||
try:
|
||||
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int
|
||||
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
|
||||
except (ValueError, TypeError):
|
||||
trace_id = uuid.uuid4().int
|
||||
trace_id = cast(int, uuid.uuid4().int)
|
||||
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class ChromaVector(BaseVector):
|
||||
self._client.get_or_create_collection(collection_name)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
@@ -73,7 +73,6 @@ class ChromaVector(BaseVector):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
# FIXME: chromadb using numpy array, fix the type error later
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
|
||||
return uuids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
|
||||
@@ -605,36 +605,25 @@ class ClickzettaVector(BaseVector):
|
||||
logger.warning("Failed to create inverted index: %s", e)
|
||||
# Continue without inverted index - full-text search will fall back to LIKE
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Add documents with embeddings to the collection."""
|
||||
if not documents:
|
||||
return []
|
||||
return
|
||||
|
||||
batch_size = self._config.batch_size
|
||||
total_batches = (len(documents) + batch_size - 1) // batch_size
|
||||
added_ids = []
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
batch_doc_ids = []
|
||||
for doc in batch_docs:
|
||||
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
||||
batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))))
|
||||
added_ids.extend(batch_doc_ids)
|
||||
|
||||
# Execute batch insert through write queue
|
||||
self._execute_write(
|
||||
self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches
|
||||
)
|
||||
|
||||
return added_ids
|
||||
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
|
||||
|
||||
def _insert_batch(
|
||||
self,
|
||||
batch_docs: list[Document],
|
||||
batch_embeddings: list[list[float]],
|
||||
batch_doc_ids: list[str],
|
||||
batch_index: int,
|
||||
batch_size: int,
|
||||
total_batches: int,
|
||||
@@ -652,9 +641,14 @@ class ClickzettaVector(BaseVector):
|
||||
data_rows = []
|
||||
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
|
||||
|
||||
for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids):
|
||||
for doc, embedding in zip(batch_docs, batch_embeddings):
|
||||
# Optimized: minimal checks for common case, fallback for edge cases
|
||||
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
||||
metadata = doc.metadata or {}
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
|
||||
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
|
||||
|
||||
# Fast path for JSON serialization
|
||||
try:
|
||||
|
||||
@@ -4,10 +4,9 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
@@ -199,9 +198,12 @@ class _InvalidTimeoutStatusError(ValueError):
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_factory: sessionmaker | Engine,
|
||||
tenant_id: str,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _delivery_method_to_model(
|
||||
@@ -215,7 +217,7 @@ class HumanInputFormRepositoryImpl:
|
||||
id=delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=delivery_method.type,
|
||||
delivery_config_id=str(delivery_method.id),
|
||||
delivery_config_id=delivery_method.id,
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputFormRecipient] = []
|
||||
@@ -341,7 +343,7 @@ class HumanInputFormRepositoryImpl:
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
start_time = naive_utc_now()
|
||||
@@ -433,7 +435,7 @@ class HumanInputFormRepositoryImpl:
|
||||
HumanInputForm.node_id == node_id,
|
||||
HumanInputForm.tenant_id == self._tenant_id,
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.scalars(form_query).first()
|
||||
if form_model is None:
|
||||
return None
|
||||
@@ -446,13 +448,18 @@ class HumanInputFormRepositoryImpl:
|
||||
class HumanInputFormSubmissionRepository:
|
||||
"""Repository for fetching and submitting human input forms."""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
@@ -471,7 +478,7 @@ class HumanInputFormSubmissionRepository:
|
||||
HumanInputFormRecipient.recipient_type == recipient_type,
|
||||
)
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
@@ -487,7 +494,7 @@ class HumanInputFormSubmissionRepository:
|
||||
submission_user_id: str | None,
|
||||
submission_end_user_id: str | None,
|
||||
) -> HumanInputFormRecord:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
@@ -517,7 +524,7 @@ class HumanInputFormSubmissionRepository:
|
||||
timeout_status: HumanInputFormStatus,
|
||||
reason: str | None = None,
|
||||
) -> HumanInputFormRecord:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
@@ -194,13 +194,6 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
|
||||
# Create a new database session
|
||||
with self._session_factory() as session:
|
||||
existing_model = session.get(WorkflowRun, db_model.id)
|
||||
if existing_model:
|
||||
if existing_model.tenant_id != self._tenant_id:
|
||||
raise ValueError("Unauthorized access to workflow run")
|
||||
# Preserve the original start time for pause/resume flows.
|
||||
db_model.created_at = existing_model.created_at
|
||||
|
||||
# SQLAlchemy merge intelligently handles both insert and update operations
|
||||
# based on the presence of the primary key
|
||||
session.merge(db_model)
|
||||
|
||||
@@ -37,7 +37,6 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
|
||||
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||
VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.helper.code_executor.code_executor import (
|
||||
@@ -20,10 +19,8 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
from dify_graph.graph.graph import NodeFactory
|
||||
@@ -37,7 +34,6 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
|
||||
from dify_graph.nodes.datasource import DatasourceNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
|
||||
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm.entities import ModelConfig
|
||||
@@ -112,7 +108,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._dify_context = self._resolve_dify_context(graph_init_params.run_context)
|
||||
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
|
||||
self._code_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
@@ -144,16 +139,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
|
||||
raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY)
|
||||
if raw_ctx is None:
|
||||
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
|
||||
if isinstance(raw_ctx, DifyRunContext):
|
||||
return raw_ctx
|
||||
return DifyRunContext.model_validate(raw_ctx)
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
@@ -219,15 +205,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HUMAN_INPUT:
|
||||
return HumanInputNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_INDEX:
|
||||
return KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
@@ -277,7 +254,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
@@ -368,7 +344,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
app_id=self.graph_init_params.app_id,
|
||||
node_data_memory=node_memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
@@ -5,26 +5,26 @@ from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||
from dify_graph.enums import UserFrom
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
||||
from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||
@@ -34,66 +34,6 @@ from models.workflow import Workflow
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
@staticmethod
|
||||
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
|
||||
"""
|
||||
Return whether `graph_config["nodes"]` contains the given node id.
|
||||
|
||||
Returns `None` when the nodes payload shape is unexpected, so graph-level
|
||||
validation can surface the original configuration error.
|
||||
"""
|
||||
nodes = graph_config.get("nodes")
|
||||
if not isinstance(nodes, list):
|
||||
return None
|
||||
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
return None
|
||||
current_id = node.get("id")
|
||||
if isinstance(current_id, str) and current_id == node_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def build_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: str,
|
||||
layers: Sequence[object] = (),
|
||||
) -> GraphEngine:
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id)
|
||||
if has_root_node is False:
|
||||
raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found")
|
||||
|
||||
child_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
child_engine = GraphEngine(
|
||||
workflow_id=workflow_id,
|
||||
graph=child_graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
child_engine_builder=self,
|
||||
)
|
||||
child_engine.layer(LLMQuotaLayer())
|
||||
for layer in layers:
|
||||
child_engine.layer(cast(GraphEngineLayer, layer))
|
||||
return child_engine
|
||||
|
||||
|
||||
class WorkflowEntry:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -137,7 +77,6 @@ class WorkflowEntry:
|
||||
command_channel = InMemoryChannel()
|
||||
|
||||
self.command_channel = command_channel
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder()
|
||||
self.graph_engine = GraphEngine(
|
||||
workflow_id=workflow_id,
|
||||
graph=graph,
|
||||
@@ -149,7 +88,6 @@ class WorkflowEntry:
|
||||
scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD,
|
||||
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
|
||||
),
|
||||
child_engine_builder=self._child_engine_builder,
|
||||
)
|
||||
|
||||
# Add debug logging layer when in debug mode
|
||||
@@ -216,15 +154,13 @@ class WorkflowEntry:
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
@@ -357,15 +293,13 @@ class WorkflowEntry:
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
workflow_id="",
|
||||
graph_config=graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
DIFY_RUN_CONTEXT_KEY = "_dify"
|
||||
from dify_graph.enums import InvokeFrom, UserFrom
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
@@ -18,7 +18,11 @@ class GraphInitParams(BaseModel):
|
||||
"""
|
||||
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
run_context: Mapping[str, Any] = Field(..., description="runtime context")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
|
||||
@@ -33,6 +33,39 @@ class SystemVariableKey(StrEnum):
|
||||
INVOKE_FROM = "invoke_from"
|
||||
|
||||
|
||||
class UserFrom(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
"""Get source of invoke from.
|
||||
|
||||
:return: source
|
||||
"""
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
|
||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from dify_graph.context import capture_current_context
|
||||
@@ -27,7 +27,6 @@ from dify_graph.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
|
||||
from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from dify_graph.runtime.graph_runtime_state import GraphProtocol
|
||||
@@ -50,7 +49,6 @@ from .protocols.command_channel import CommandChannel
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
|
||||
from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
@@ -76,7 +74,6 @@ class GraphEngine:
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: CommandChannel,
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
@@ -86,9 +83,6 @@ class GraphEngine:
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
self._config = config
|
||||
self._child_engine_builder = child_engine_builder
|
||||
if child_engine_builder is not None:
|
||||
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
@@ -220,25 +214,6 @@ class GraphEngine:
|
||||
self._bind_layer_context(layer)
|
||||
return self
|
||||
|
||||
def create_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: dict[str, object] | Mapping[str, object],
|
||||
root_node_id: str,
|
||||
layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (),
|
||||
) -> GraphEngine:
|
||||
return self._graph_runtime_state.create_child_engine(
|
||||
workflow_id=workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id,
|
||||
layers=layers,
|
||||
)
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Execute the graph using the modular architecture.
|
||||
|
||||
@@ -80,11 +80,9 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
@@ -122,8 +120,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
@@ -146,8 +144,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
@@ -285,13 +283,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
dify_ctx = self.require_dify_context()
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
entity,
|
||||
dify_ctx.invoke_from,
|
||||
runtime_variable_pool,
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
@@ -403,8 +396,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
dify_ctx = self.require_dify_context()
|
||||
plugins = manager.list_plugins(dify_ctx.tenant_id)
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
@@ -425,11 +417,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
@@ -440,10 +429,9 @@ class AgentNode(Node[AgentNodeData]):
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
@@ -452,7 +440,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
|
||||
@@ -8,11 +8,10 @@ from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
@@ -65,28 +64,10 @@ from libs.datetime_utils import naive_utc_now
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyRunContextProtocol(Protocol):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: Any
|
||||
invoke_from: Any
|
||||
|
||||
|
||||
class _MappingDifyRunContext:
|
||||
def __init__(self, mapping: Mapping[str, Any]) -> None:
|
||||
self.tenant_id = str(mapping["tenant_id"])
|
||||
self.app_id = str(mapping["app_id"])
|
||||
self.user_id = str(mapping["user_id"])
|
||||
self.user_from = mapping["user_from"]
|
||||
self.invoke_from = mapping["invoke_from"]
|
||||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
"""BaseNode serves as the foundational class for all node implementations.
|
||||
|
||||
@@ -246,10 +227,14 @@ class Node(Generic[NodeDataT]):
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self._run_context = MappingProxyType(dict(graph_init_params.run_context))
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.graph_config = graph_init_params.graph_config
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = graph_init_params.user_from
|
||||
self.invoke_from = graph_init_params.invoke_from
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
@@ -278,38 +263,6 @@ class Node(Generic[NodeDataT]):
|
||||
def graph_init_params(self) -> GraphInitParams:
|
||||
return self._graph_init_params
|
||||
|
||||
@property
|
||||
def run_context(self) -> Mapping[str, Any]:
|
||||
return self._run_context
|
||||
|
||||
def get_run_context_value(self, key: str, default: Any = None) -> Any:
|
||||
return self._run_context.get(key, default)
|
||||
|
||||
def require_run_context_value(self, key: str) -> Any:
|
||||
value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE)
|
||||
if value is _MISSING_RUN_CONTEXT_VALUE:
|
||||
raise ValueError(f"run_context missing required key: {key}")
|
||||
return value
|
||||
|
||||
def require_dify_context(self) -> DifyRunContextProtocol:
|
||||
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
|
||||
if raw_ctx is None:
|
||||
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
|
||||
|
||||
if isinstance(raw_ctx, Mapping):
|
||||
missing_keys = [
|
||||
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
|
||||
]
|
||||
if missing_keys:
|
||||
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
|
||||
return _MappingDifyRunContext(raw_ctx)
|
||||
|
||||
for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
|
||||
if not hasattr(raw_ctx, attr):
|
||||
raise TypeError(f"invalid dify context object, missing attribute: {attr}")
|
||||
|
||||
return cast(DifyRunContextProtocol, raw_ctx)
|
||||
|
||||
@property
|
||||
def execution_id(self) -> str:
|
||||
return self._node_execution_id
|
||||
|
||||
@@ -52,7 +52,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
@@ -76,7 +75,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
datasource_info["icon"] = self.datasource_manager.get_icon_url(
|
||||
provider_id=provider_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=datasource_type.value,
|
||||
)
|
||||
|
||||
@@ -105,11 +104,11 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
yield from self.datasource_manager.stream_node_events(
|
||||
node_id=self._node_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
user_id=self.user_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
datasource_type=datasource_type.value,
|
||||
provider_id=provider_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=credential_id,
|
||||
@@ -137,7 +136,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
|
||||
file_info = self.datasource_manager.get_upload_file_by_id(
|
||||
file_id=related_id, tenant_id=dify_ctx.tenant_id
|
||||
file_id=related_id, tenant_id=self.tenant_id
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -21,11 +20,11 @@ from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod, file_manager
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.variables import ArrayFileSegment
|
||||
from dify_graph.variables.segments import ArrayStringSegment, FileSegment
|
||||
|
||||
@@ -59,7 +58,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
unstructured_api_config: UnstructuredApiConfig | None = None,
|
||||
http_client: HttpClientProtocol,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -68,7 +66,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
|
||||
self._http_client = http_client
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self.node_data.variable_selector
|
||||
@@ -83,24 +80,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
|
||||
value = variable.value
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
if isinstance(value, list):
|
||||
value = list(filter(lambda x: x, value))
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
if not value:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": ArrayStringSegment(value=[])},
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = [
|
||||
_extract_text_from_file(
|
||||
self._http_client, file, unstructured_api_config=self._unstructured_api_config
|
||||
)
|
||||
_extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
|
||||
for file in value
|
||||
]
|
||||
return NodeRunResult(
|
||||
@@ -110,9 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
|
||||
)
|
||||
elif isinstance(value, File):
|
||||
extracted_text = _extract_text_from_file(
|
||||
self._http_client, value, unstructured_api_config=self._unstructured_api_config
|
||||
)
|
||||
extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
@@ -122,7 +105,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
else:
|
||||
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
|
||||
except DocumentExtractorError as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
@@ -397,32 +379,6 @@ def parser_docx_part(block, doc: Document, content_items, i):
|
||||
content_items.append((i, "table", Table(block, doc)))
|
||||
|
||||
|
||||
def _normalize_docx_zip(file_content: bytes) -> bytes:
|
||||
"""
|
||||
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
|
||||
ZIP entry names use backslash (\\) as path separator instead of the forward
|
||||
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
|
||||
"word\\document.xml" is never found when python-docx looks for
|
||||
"word/document.xml", which triggers a KeyError about a missing relationship.
|
||||
|
||||
This function rewrites the ZIP in-memory, normalizing all entry names to
|
||||
use forward slashes without touching any actual document content.
|
||||
"""
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
|
||||
out_buf = io.BytesIO()
|
||||
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
|
||||
for item in zin.infolist():
|
||||
data = zin.read(item.filename)
|
||||
# Normalize backslash path separators to forward slash
|
||||
item.filename = item.filename.replace("\\", "/")
|
||||
zout.writestr(item, data)
|
||||
return out_buf.getvalue()
|
||||
except zipfile.BadZipFile:
|
||||
# Not a valid zip — return as-is and let python-docx report the real error
|
||||
return file_content
|
||||
|
||||
|
||||
def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOCX file.
|
||||
@@ -430,15 +386,7 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
try:
|
||||
doc = docx.Document(doc_file)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
|
||||
# Some DOCX files exported by tools like Evernote on Windows use
|
||||
# backslash path separators in ZIP entries and/or single-quoted XML
|
||||
# attributes, both of which break python-docx on Linux. Normalize and retry.
|
||||
file_content = _normalize_docx_zip(file_content)
|
||||
doc = docx.Document(io.BytesIO(file_content))
|
||||
doc = docx.Document(doc_file)
|
||||
text = []
|
||||
|
||||
# Keep track of paragraph and table positions
|
||||
@@ -491,13 +439,13 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes:
|
||||
def _download_file_content(file: File) -> bytes:
|
||||
"""Download the content of a file based on its transfer method."""
|
||||
try:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if file.remote_url is None:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = http_client.get(file.remote_url)
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
@@ -506,10 +454,8 @@ def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_file(
|
||||
http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig
|
||||
) -> str:
|
||||
file_content = _download_file_content(http_client, file)
|
||||
def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
file_content = _download_file_content(file)
|
||||
if file.extension:
|
||||
extracted_text = _extract_text_by_file_extension(
|
||||
file_content=file_content,
|
||||
|
||||
@@ -212,7 +212,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
"""
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
dify_ctx = self.require_dify_context()
|
||||
files: list[File] = []
|
||||
is_file = response.is_file
|
||||
content_type = response.content_type
|
||||
@@ -237,8 +236,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
tool_file_manager = self._tool_file_manager_factory()
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=content,
|
||||
mimetype=mime_type,
|
||||
@@ -250,7 +249,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
@@ -20,6 +21,7 @@ from dify_graph.repositories.human_input_form_repository import (
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
|
||||
@@ -31,8 +33,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
_INVOKE_FROM_DEBUGGER = "debugger"
|
||||
_INVOKE_FROM_EXPLORE = "explore"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,7 +66,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository,
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -74,6 +74,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if form_repository is None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
self._form_repository = form_repository
|
||||
|
||||
@classmethod
|
||||
@@ -157,39 +162,30 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
return resolved_defaults
|
||||
|
||||
def _should_require_console_recipient(self) -> bool:
|
||||
invoke_from = self._invoke_from_value()
|
||||
if invoke_from == _INVOKE_FROM_DEBUGGER:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
if invoke_from == _INVOKE_FROM_EXPLORE:
|
||||
if self.invoke_from == InvokeFrom.EXPLORE:
|
||||
return self._node_data.is_webapp_enabled()
|
||||
return False
|
||||
|
||||
def _display_in_ui(self) -> bool:
|
||||
if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
return self._node_data.is_webapp_enabled()
|
||||
|
||||
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
invoke_from = self._invoke_from_value()
|
||||
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
|
||||
if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
|
||||
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
|
||||
return [
|
||||
apply_debug_email_recipient(
|
||||
method,
|
||||
enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
|
||||
user_id=dify_ctx.user_id,
|
||||
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
|
||||
user_id=self.user_id or "",
|
||||
)
|
||||
for method in enabled_methods
|
||||
]
|
||||
|
||||
def _invoke_from_value(self) -> str:
|
||||
invoke_from = self.require_dify_context().invoke_from
|
||||
if isinstance(invoke_from, str):
|
||||
return invoke_from
|
||||
return str(getattr(invoke_from, "value", invoke_from))
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
resolved_default_values = self.resolve_default_values()
|
||||
@@ -223,11 +219,10 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"""
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
dify_ctx = self.require_dify_context()
|
||||
if form is None:
|
||||
display_in_ui = self._display_in_ui()
|
||||
params = FormCreateParams(
|
||||
app_id=dify_ctx.app_id,
|
||||
app_id=self.app_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
@@ -237,9 +232,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
resolved_default_values=self.resolve_default_values(),
|
||||
console_recipient_required=self._should_require_console_recipient(),
|
||||
console_creator_account_id=(
|
||||
dify_ctx.user_id
|
||||
if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
|
||||
else None
|
||||
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
|
||||
),
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
|
||||
@@ -587,14 +587,24 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams for child graph execution.
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
run_context=self.run_context,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
# Create a deep copy of the variable pool for each iteration
|
||||
@@ -611,17 +621,28 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
root_node_id = self.node_data.start_node_id
|
||||
if root_node_id is None:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
try:
|
||||
return self.graph_runtime_state.create_child_engine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
graph_config=self.graph_config,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
except ChildGraphNotFoundError as exc:
|
||||
raise IterationGraphNotFoundError("iteration graph not found") from exc
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the iteration graph with the new node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=iteration_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.template import Template
|
||||
@@ -20,7 +20,6 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_INVOKE_FROM_DEBUGGER = "debugger"
|
||||
|
||||
|
||||
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
@@ -59,8 +58,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
if not variable:
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
invoke_from_value = str(invoke_from.value) if invoke_from else None
|
||||
is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
|
||||
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
|
||||
@@ -23,11 +23,7 @@ from dify_graph.variables import (
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayObjectSegment
|
||||
|
||||
from .entities import (
|
||||
Condition,
|
||||
KnowledgeRetrievalNodeData,
|
||||
MetadataFilteringCondition,
|
||||
)
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .exc import (
|
||||
KnowledgeRetrievalNodeError,
|
||||
RateLimitExceededError,
|
||||
@@ -70,10 +66,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@@ -120,7 +115,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@@ -165,7 +160,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
|
||||
) -> tuple[list[Source], LLMUsage]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
dataset_ids = node_data.dataset_ids
|
||||
query = variables.get("query")
|
||||
attachments = variables.get("attachments")
|
||||
@@ -175,12 +169,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if node_data.metadata_filtering_mode is not None:
|
||||
metadata_filtering_mode = node_data.metadata_filtering_mode
|
||||
|
||||
resolved_metadata_conditions = (
|
||||
self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
|
||||
if node_data.metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
@@ -188,10 +176,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
model = node_data.single_retrieval_config.model
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
user_from=dify_ctx.user_from.value,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_from=self.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
|
||||
completion_params=model.completion_params,
|
||||
@@ -199,7 +187,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
model_mode=model.mode,
|
||||
model_name=model.name,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=resolved_metadata_conditions,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
query=query,
|
||||
)
|
||||
@@ -241,10 +229,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
app_id=dify_ctx.app_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=dify_ctx.user_id,
|
||||
user_from=dify_ctx.user_from.value,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
query=query,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
|
||||
@@ -257,7 +245,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=resolved_metadata_conditions,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||
)
|
||||
@@ -266,48 +254,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
usage = self._rag_retrieval.llm_usage
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _resolve_metadata_filtering_conditions(
|
||||
self, conditions: MetadataFilteringCondition
|
||||
) -> MetadataFilteringCondition:
|
||||
if conditions.conditions is None:
|
||||
return MetadataFilteringCondition(
|
||||
logical_operator=conditions.logical_operator,
|
||||
conditions=None,
|
||||
)
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
else:
|
||||
resolved_value = value
|
||||
resolved_conditions.append(
|
||||
Condition(
|
||||
name=cond.name,
|
||||
comparison_operator=cond.comparison_operator,
|
||||
value=resolved_value,
|
||||
)
|
||||
)
|
||||
return MetadataFilteringCondition(
|
||||
logical_operator=conditions.logical_operator or "and",
|
||||
conditions=resolved_conditions,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@@ -145,10 +145,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
self._memory = memory
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@@ -243,7 +242,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.require_dify_context().user_id,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||
structured_output=self.node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
@@ -703,7 +702,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.require_dify_context().tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
|
||||
@@ -412,14 +412,24 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams for child graph execution.
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
run_context=self.run_context,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
@@ -429,10 +439,22 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
start_at=start_at.timestamp(),
|
||||
)
|
||||
|
||||
return self.graph_runtime_state.create_child_engine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
graph_config=self.graph_config,
|
||||
root_node_id=root_node_id,
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the loop graph with the new node factory
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=loop_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
tools=tools,
|
||||
stop=list(stop),
|
||||
stream=False,
|
||||
user=self.require_dify_context().user_id,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
||||
@@ -86,10 +86,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self._memory = memory
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@@ -161,7 +160,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.require_dify_context().user_id,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=False,
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
|
||||
@@ -56,8 +56,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
"provider_type": self.node_data.provider_type.value,
|
||||
@@ -77,12 +75,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
self._node_id,
|
||||
self.node_data,
|
||||
dify_ctx.invoke_from,
|
||||
variable_pool,
|
||||
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
@@ -116,10 +109,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=dify_ctx.app_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
@@ -140,8 +133,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_id=self._node_id,
|
||||
tool_runtime=tool_runtime,
|
||||
)
|
||||
|
||||
@@ -69,7 +69,6 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
)
|
||||
|
||||
def generate_file_var(self, param_name: str, file: dict):
|
||||
dify_ctx = self.require_dify_context()
|
||||
related_id = file.get("related_id")
|
||||
transfer_method_value = file.get("transfer_method")
|
||||
if transfer_method_value:
|
||||
@@ -85,7 +84,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
try:
|
||||
file_obj = file_factory.build_from_mapping(
|
||||
mapping=file,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
|
||||
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
from .graph_runtime_state import (
|
||||
ChildEngineBuilderNotConfiguredError,
|
||||
ChildEngineError,
|
||||
ChildGraphNotFoundError,
|
||||
GraphRuntimeState,
|
||||
)
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
from .variable_pool import VariablePool, VariableValue
|
||||
|
||||
__all__ = [
|
||||
"ChildEngineBuilderNotConfiguredError",
|
||||
"ChildEngineError",
|
||||
"ChildGraphNotFoundError",
|
||||
"GraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
|
||||
@@ -15,7 +15,6 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
@@ -136,31 +135,6 @@ class GraphProtocol(Protocol):
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
class ChildGraphEngineBuilderProtocol(Protocol):
|
||||
def build_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: str,
|
||||
layers: Sequence[object] = (),
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class ChildEngineError(ValueError):
|
||||
"""Base error type for child-engine creation failures."""
|
||||
|
||||
|
||||
class ChildEngineBuilderNotConfiguredError(ChildEngineError):
|
||||
"""Raised when child-engine creation is requested without a bound builder."""
|
||||
|
||||
|
||||
class ChildGraphNotFoundError(ChildEngineError):
|
||||
"""Raised when the requested child graph entry point cannot be resolved."""
|
||||
|
||||
|
||||
class _GraphStateSnapshot(BaseModel):
|
||||
"""Serializable graph state snapshot for node/edge states."""
|
||||
|
||||
@@ -235,7 +209,6 @@ class GraphRuntimeState:
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None
|
||||
|
||||
# Node and edges states needed to be restored into
|
||||
# graph object.
|
||||
@@ -277,31 +250,6 @@ class GraphRuntimeState:
|
||||
if self._graph is not None:
|
||||
_ = self.response_coordinator
|
||||
|
||||
def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None:
|
||||
self._child_engine_builder = builder
|
||||
|
||||
def create_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: str,
|
||||
layers: Sequence[object] = (),
|
||||
) -> Any:
|
||||
if self._child_engine_builder is None:
|
||||
raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.")
|
||||
|
||||
return self._child_engine_builder.build_child_engine(
|
||||
workflow_id=workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id,
|
||||
layers=layers,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Primary collaborators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -65,15 +65,9 @@ class VariablePool(BaseModel):
|
||||
# Add environment variables to the variable pool
|
||||
for var in self.environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
# Add conversation variables to the variable pool. When restoring from a serialized
|
||||
# snapshot, `variable_dictionary` already carries the latest runtime values.
|
||||
# In that case, keep existing entries instead of overwriting them with the
|
||||
# bootstrap list.
|
||||
# Add conversation variables to the variable pool
|
||||
for var in self.conversation_variables:
|
||||
selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
|
||||
if self._has(selector):
|
||||
continue
|
||||
self.add(selector, var)
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
# Add rag pipeline variables to the variable pool
|
||||
if self.rag_pipeline_variables:
|
||||
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
|
||||
|
||||
@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from events.app_event import app_model_config_was_updated
|
||||
@@ -56,11 +54,9 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s
|
||||
continue
|
||||
|
||||
tool_type = list(tool.keys())[0]
|
||||
tool_config = cast(dict[str, Any], list(tool.values())[0])
|
||||
tool_config = list(tool.values())[0]
|
||||
if tool_type == "dataset":
|
||||
dataset_id = tool_config.get("id")
|
||||
if isinstance(dataset_id, str):
|
||||
dataset_ids.add(dataset_id)
|
||||
dataset_ids.add(tool_config.get("id"))
|
||||
|
||||
# get dataset from dataset_configs
|
||||
dataset_configs = app_model_config.dataset_configs_dict
|
||||
|
||||
@@ -13,7 +13,6 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
@@ -67,7 +66,6 @@ def init_app(app: DifyApp):
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -18,7 +18,6 @@ from dify_app import DifyApp
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
@@ -182,18 +181,13 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
|
||||
|
||||
sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]
|
||||
|
||||
sentinel_kwargs = {
|
||||
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
|
||||
"username": dify_config.REDIS_SENTINEL_USERNAME,
|
||||
"password": dify_config.REDIS_SENTINEL_PASSWORD,
|
||||
}
|
||||
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
sentinel_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
|
||||
sentinel = Sentinel(
|
||||
sentinel_hosts,
|
||||
sentinel_kwargs=sentinel_kwargs,
|
||||
sentinel_kwargs={
|
||||
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
|
||||
"username": dify_config.REDIS_SENTINEL_USERNAME,
|
||||
"password": dify_config.REDIS_SENTINEL_PASSWORD,
|
||||
},
|
||||
)
|
||||
|
||||
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
|
||||
@@ -210,15 +204,12 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
|
||||
for node in dify_config.REDIS_CLUSTERS.split(",")
|
||||
]
|
||||
|
||||
cluster_kwargs: dict[str, Any] = {
|
||||
"startup_nodes": nodes,
|
||||
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
|
||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
"cache_config": _get_cache_configuration(),
|
||||
}
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
cluster: RedisCluster = RedisCluster(**cluster_kwargs)
|
||||
cluster: RedisCluster = RedisCluster(
|
||||
startup_nodes=nodes,
|
||||
password=dify_config.REDIS_CLUSTERS_PASSWORD,
|
||||
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
cache_config=_get_cache_configuration(),
|
||||
)
|
||||
return cluster
|
||||
|
||||
|
||||
@@ -234,9 +225,6 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
}
|
||||
)
|
||||
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
|
||||
if ssl_kwargs:
|
||||
redis_params.update(ssl_kwargs)
|
||||
|
||||
@@ -246,17 +234,9 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
|
||||
|
||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
|
||||
max_conns = dify_config.REDIS_MAX_CONNECTIONS
|
||||
if use_clusters:
|
||||
if max_conns:
|
||||
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
|
||||
else:
|
||||
return RedisCluster.from_url(pubsub_url)
|
||||
|
||||
if max_conns:
|
||||
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
|
||||
else:
|
||||
return redis.Redis.from_url(pubsub_url)
|
||||
return RedisCluster.from_url(pubsub_url)
|
||||
return redis.Redis.from_url(pubsub_url)
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
@@ -289,11 +269,6 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||
return StreamsBroadcastChannel(
|
||||
_pubsub_redis_client,
|
||||
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||
)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Union
|
||||
|
||||
from celery.signals import worker_init
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
from opentelemetry.propagators.b3 import B3Format
|
||||
from opentelemetry.propagators.composite import CompositePropagator
|
||||
@@ -31,29 +31,9 @@ def setup_context_propagation() -> None:
|
||||
|
||||
|
||||
def shutdown_tracer() -> None:
|
||||
flush_telemetry()
|
||||
|
||||
|
||||
def flush_telemetry() -> None:
|
||||
"""
|
||||
Best-effort flush for telemetry providers.
|
||||
|
||||
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
|
||||
so counters/histograms are exported before the process exits.
|
||||
"""
|
||||
provider = trace.get_tracer_provider()
|
||||
if hasattr(provider, "force_flush"):
|
||||
try:
|
||||
provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush trace provider")
|
||||
|
||||
metric_provider = metrics.get_meter_provider()
|
||||
if hasattr(metric_provider, "force_flush"):
|
||||
try:
|
||||
metric_provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush metric provider")
|
||||
provider.force_flush()
|
||||
|
||||
|
||||
def is_celery_worker():
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamsBroadcastChannel:
|
||||
"""
|
||||
Redis Streams based broadcast channel implementation.
|
||||
|
||||
Characteristics:
|
||||
- At-least-once delivery for late subscribers within the stream retention window.
|
||||
- Each topic is stored as a dedicated Redis Stream key.
|
||||
- The stream key expires `retention_seconds` after the last event is published (to bound storage).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
|
||||
self._client = redis_client
|
||||
self._retention_seconds = max(int(retention_seconds or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> StreamsTopic:
|
||||
return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds)
|
||||
|
||||
|
||||
class StreamsTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = f"stream:{topic}"
|
||||
self._retention_seconds = retention_seconds
|
||||
self.max_length = 5000
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length)
|
||||
if self._retention_seconds > 0:
|
||||
try:
|
||||
self._client.expire(self._key, self._retention_seconds)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True)
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _StreamsSubscription(self._client, self._key)
|
||||
|
||||
|
||||
class _StreamsSubscription(Subscription):
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, client: Redis | RedisCluster, key: str):
|
||||
self._client = client
|
||||
self._key = key
|
||||
self._closed = threading.Event()
|
||||
self._last_id = "0-0"
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
self._start_lock = threading.Lock()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
||||
def _listen(self) -> None:
|
||||
try:
|
||||
while not self._closed.is_set():
|
||||
streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
|
||||
|
||||
if not streams:
|
||||
continue
|
||||
|
||||
for _key, entries in streams:
|
||||
for entry_id, fields in entries:
|
||||
data = None
|
||||
if isinstance(fields, dict):
|
||||
data = fields.get(b"data")
|
||||
data_bytes: bytes | None = None
|
||||
if isinstance(data, str):
|
||||
data_bytes = data.encode()
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
self._queue.put_nowait(data_bytes)
|
||||
self._last_id = entry_id
|
||||
finally:
|
||||
self._queue.put_nowait(self._SENTINEL)
|
||||
self._listener = None
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
if self._listener is not None:
|
||||
return
|
||||
# Ensure only one listener thread is created under concurrent calls
|
||||
with self._start_lock:
|
||||
if self._listener is not None or self._closed.is_set():
|
||||
return
|
||||
self._listener = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-streams-sub-{self._key}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener.start()
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
# Iterator delegates to receive with timeout; stops on closure.
|
||||
self._start_if_needed()
|
||||
while not self._closed.is_set():
|
||||
item = self.receive(timeout=1)
|
||||
if item is not None:
|
||||
yield item
|
||||
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis streams subscription is closed")
|
||||
self._start_if_needed()
|
||||
|
||||
try:
|
||||
if timeout is None:
|
||||
item = self._queue.get()
|
||||
else:
|
||||
item = self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
if item is self._SENTINEL or self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis streams subscription is closed")
|
||||
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
|
||||
return bytes(item)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed.is_set():
|
||||
return
|
||||
self._closed.set()
|
||||
listener = self._listener
|
||||
if listener is not None:
|
||||
listener.join(timeout=2.0)
|
||||
if listener.is_alive():
|
||||
logger.warning(
|
||||
"Streams subscription listener for key %s did not stop within timeout; keeping reference.",
|
||||
self._key,
|
||||
)
|
||||
else:
|
||||
self._listener = None
|
||||
|
||||
# Context manager helpers
|
||||
def __enter__(self) -> Self:
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
@@ -66,7 +66,6 @@ def run_migrations_offline():
|
||||
context.configure(
|
||||
url=url, target_metadata=get_metadata(), literal_binds=True
|
||||
)
|
||||
logger.info("Generating offline migration SQL with url: %s", url)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""add partial indexes on conversations for app_id with created_at and updated_at
|
||||
|
||||
Revision ID: e288952f2994
|
||||
Revises: fce013ca180e
|
||||
Create Date: 2026-02-26 13:36:45.928922
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e288952f2994'
|
||||
down_revision = 'fce013ca180e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
'conversation_app_created_at_idx',
|
||||
['app_id', sa.literal_column('created_at DESC')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('is_deleted IS false'),
|
||||
)
|
||||
batch_op.create_index(
|
||||
'conversation_app_updated_at_idx',
|
||||
['app_id', sa.literal_column('updated_at DESC')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('is_deleted IS false'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.drop_index('conversation_app_updated_at_idx')
|
||||
batch_op.drop_index('conversation_app_created_at_idx')
|
||||
@@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -15,7 +15,6 @@ from flask import request
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
@@ -37,259 +36,6 @@ if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
|
||||
|
||||
# --- TypedDict definitions for structured dict return types ---
|
||||
|
||||
|
||||
class EnabledConfig(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EmbeddingModelInfo(TypedDict):
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class AnnotationReplyDisabledConfig(TypedDict):
|
||||
enabled: Literal[False]
|
||||
|
||||
|
||||
class AnnotationReplyEnabledConfig(TypedDict):
|
||||
id: str
|
||||
enabled: Literal[True]
|
||||
score_threshold: float
|
||||
embedding_model: EmbeddingModelInfo
|
||||
|
||||
|
||||
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfig(TypedDict):
|
||||
enabled: bool
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class AgentToolConfig(TypedDict):
|
||||
provider_type: str
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any]
|
||||
plugin_unique_identifier: NotRequired[str | None]
|
||||
credential_id: NotRequired[str | None]
|
||||
|
||||
|
||||
class AgentModeConfig(TypedDict):
|
||||
enabled: bool
|
||||
strategy: str | None
|
||||
tools: list[AgentToolConfig | dict[str, Any]]
|
||||
prompt: str | None
|
||||
|
||||
|
||||
class ImageUploadConfig(TypedDict):
|
||||
enabled: bool
|
||||
number_limits: int
|
||||
detail: str
|
||||
transfer_methods: list[str]
|
||||
|
||||
|
||||
class FileUploadConfig(TypedDict):
|
||||
image: ImageUploadConfig
|
||||
|
||||
|
||||
class DeletedToolInfo(TypedDict):
|
||||
type: str
|
||||
tool_name: str
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ExternalDataToolConfig(TypedDict):
|
||||
enabled: bool
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class UserInputFormItemConfig(TypedDict):
|
||||
variable: str
|
||||
label: str
|
||||
description: NotRequired[str]
|
||||
required: NotRequired[bool]
|
||||
max_length: NotRequired[int]
|
||||
options: NotRequired[list[str]]
|
||||
default: NotRequired[str]
|
||||
type: NotRequired[str]
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
|
||||
UserInputFormItem = dict[str, UserInputFormItemConfig]
|
||||
|
||||
|
||||
class DatasetConfigs(TypedDict):
|
||||
retrieval_model: str
|
||||
datasets: NotRequired[dict[str, Any]]
|
||||
top_k: NotRequired[int]
|
||||
score_threshold: NotRequired[float]
|
||||
score_threshold_enabled: NotRequired[bool]
|
||||
reranking_model: NotRequired[dict[str, Any] | None]
|
||||
weights: NotRequired[dict[str, Any] | None]
|
||||
reranking_enabled: NotRequired[bool]
|
||||
reranking_mode: NotRequired[str]
|
||||
metadata_filtering_mode: NotRequired[str]
|
||||
metadata_model_config: NotRequired[dict[str, Any] | None]
|
||||
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
|
||||
|
||||
|
||||
class ChatPromptMessage(TypedDict):
|
||||
text: str
|
||||
role: str
|
||||
|
||||
|
||||
class ChatPromptConfig(TypedDict, total=False):
|
||||
prompt: list[ChatPromptMessage]
|
||||
|
||||
|
||||
class CompletionPromptText(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class ConversationHistoriesRole(TypedDict):
|
||||
user_prefix: str
|
||||
assistant_prefix: str
|
||||
|
||||
|
||||
class CompletionPromptConfig(TypedDict):
|
||||
prompt: CompletionPromptText
|
||||
conversation_histories_role: NotRequired[ConversationHistoriesRole]
|
||||
|
||||
|
||||
class ModelConfig(TypedDict):
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class AppModelConfigDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: EnabledConfig
|
||||
speech_to_text: EnabledConfig
|
||||
text_to_speech: EnabledConfig
|
||||
retriever_resource: EnabledConfig
|
||||
annotation_reply: AnnotationReplyConfig
|
||||
more_like_this: EnabledConfig
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
|
||||
external_data_tools: list[ExternalDataToolConfig]
|
||||
model: ModelConfig
|
||||
user_input_form: list[UserInputFormItem]
|
||||
dataset_query_variable: str | None
|
||||
pre_prompt: str | None
|
||||
agent_mode: AgentModeConfig
|
||||
prompt_type: str
|
||||
chat_prompt_config: ChatPromptConfig
|
||||
completion_prompt_config: CompletionPromptConfig
|
||||
dataset_configs: DatasetConfigs
|
||||
file_upload: FileUploadConfig
|
||||
# Added dynamically in Conversation.model_config
|
||||
model_id: NotRequired[str | None]
|
||||
provider: NotRequired[str | None]
|
||||
|
||||
|
||||
class ConversationDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
app_model_config_id: str | None
|
||||
model_provider: str | None
|
||||
override_model_configs: str | None
|
||||
model_id: str | None
|
||||
mode: str
|
||||
name: str
|
||||
summary: str | None
|
||||
inputs: dict[str, Any]
|
||||
introduction: str | None
|
||||
system_instruction: str | None
|
||||
system_instruction_tokens: int
|
||||
status: str
|
||||
invoke_from: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
read_at: datetime | None
|
||||
read_account_id: str | None
|
||||
dialogue_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MessageDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
model_id: str | None
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
total_price: Decimal | None
|
||||
message: dict[str, Any]
|
||||
answer: str
|
||||
status: str
|
||||
error: str | None
|
||||
message_metadata: dict[str, Any]
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
agent_based: bool
|
||||
workflow_run_id: str | None
|
||||
|
||||
|
||||
class MessageFeedbackDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class MessageFileInfo(TypedDict, total=False):
|
||||
belongs_to: str | None
|
||||
upload_file_id: str | None
|
||||
id: str
|
||||
tenant_id: str
|
||||
type: str
|
||||
transfer_method: str
|
||||
remote_url: str | None
|
||||
related_id: str | None
|
||||
filename: str | None
|
||||
extension: str | None
|
||||
mime_type: str | None
|
||||
size: int
|
||||
dify_model_identity: str
|
||||
url: str | None
|
||||
|
||||
|
||||
class ExtraContentDict(TypedDict, total=False):
|
||||
type: str
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
class TraceAppConfigDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
tracing_provider: str | None
|
||||
tracing_config: dict[str, Any]
|
||||
is_active: bool
|
||||
created_at: str | None
|
||||
updated_at: str | None
|
||||
|
||||
|
||||
class DifySetup(TypeBase):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
@@ -430,7 +176,7 @@ class App(Base):
|
||||
return str(self.mode)
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list[DeletedToolInfo]:
|
||||
def deleted_tools(self) -> list[dict[str, str]]:
|
||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
@@ -511,7 +257,7 @@ class App(Base):
|
||||
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||
}
|
||||
|
||||
deleted_tools: list[DeletedToolInfo] = []
|
||||
deleted_tools: list[dict[str, str]] = []
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
@@ -618,38 +364,35 @@ class AppModelConfig(TypeBase):
|
||||
return app
|
||||
|
||||
@property
|
||||
def model_dict(self) -> ModelConfig:
|
||||
return cast(ModelConfig, json.loads(self.model) if self.model else {})
|
||||
def model_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.model) if self.model else {}
|
||||
|
||||
@property
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig,
|
||||
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
else {"enabled": False},
|
||||
else {"enabled": False}
|
||||
)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
|
||||
def speech_to_text_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
|
||||
def text_to_speech_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
)
|
||||
def retriever_resource_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
def annotation_reply_dict(self) -> dict[str, Any]:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
)
|
||||
@@ -672,62 +415,56 @@ class AppModelConfig(TypeBase):
|
||||
return {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
|
||||
def more_like_this_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
|
||||
return cast(
|
||||
SensitiveWordAvoidanceConfig,
|
||||
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.sensitive_word_avoidance)
|
||||
if self.sensitive_word_avoidance
|
||||
else {"enabled": False, "type": "", "config": {}},
|
||||
else {"enabled": False, "type": "", "configs": []}
|
||||
)
|
||||
|
||||
@property
|
||||
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
|
||||
def external_data_tools_list(self) -> list[dict[str, Any]]:
|
||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||
|
||||
@property
|
||||
def user_input_form_list(self) -> list[UserInputFormItem]:
|
||||
def user_input_form_list(self) -> list[dict[str, Any]]:
|
||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||
|
||||
@property
|
||||
def agent_mode_dict(self) -> AgentModeConfig:
|
||||
return cast(
|
||||
AgentModeConfig,
|
||||
def agent_mode_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.agent_mode)
|
||||
if self.agent_mode
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_prompt_config_dict(self) -> ChatPromptConfig:
|
||||
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
|
||||
def chat_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
|
||||
|
||||
@property
|
||||
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
|
||||
return cast(
|
||||
CompletionPromptConfig,
|
||||
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
|
||||
)
|
||||
def completion_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> DatasetConfigs:
|
||||
def dataset_configs_dict(self) -> dict[str, Any]:
|
||||
if self.dataset_configs:
|
||||
dataset_configs = json.loads(self.dataset_configs)
|
||||
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
|
||||
if "retrieval_model" not in dataset_configs:
|
||||
return {"retrieval_model": "single"}
|
||||
else:
|
||||
return cast(DatasetConfigs, dataset_configs)
|
||||
return dataset_configs
|
||||
return {
|
||||
"retrieval_model": "multiple",
|
||||
}
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> FileUploadConfig:
|
||||
return cast(
|
||||
FileUploadConfig,
|
||||
def file_upload_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
json.loads(self.file_upload)
|
||||
if self.file_upload
|
||||
else {
|
||||
@@ -737,10 +474,10 @@ class AppModelConfig(TypeBase):
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> AppModelConfigDict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
@@ -764,42 +501,36 @@ class AppModelConfig(TypeBase):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
def from_model_config_dict(self, model_config: Mapping[str, Any]):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
json.dumps(model_config["suggested_questions_after_answer"])
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
|
||||
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
|
||||
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
json.dumps(model_config["sensitive_word_avoidance"])
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
|
||||
)
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.prompt_type = model_config.get("prompt_type", "simple")
|
||||
self.chat_prompt_config = (
|
||||
@@ -980,18 +711,6 @@ class Conversation(Base):
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||
sa.Index(
|
||||
"conversation_app_created_at_idx",
|
||||
"app_id",
|
||||
sa.text("created_at DESC"),
|
||||
postgresql_where=sa.text("is_deleted IS false"),
|
||||
),
|
||||
sa.Index(
|
||||
"conversation_app_updated_at_idx",
|
||||
"app_id",
|
||||
sa.text("updated_at DESC"),
|
||||
postgresql_where=sa.text("is_deleted IS false"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
@@ -1092,26 +811,24 @@ class Conversation(Base):
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def model_config(self) -> AppModelConfigDict:
|
||||
model_config = cast(AppModelConfigDict, {})
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
app_model_config: AppModelConfig | None = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
model_config = cast(AppModelConfigDict, override_model_configs)
|
||||
model_config = override_model_configs
|
||||
else:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
|
||||
if "model" in override_model_configs:
|
||||
# where is app_id?
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
|
||||
cast(AppModelConfigDict, override_model_configs)
|
||||
)
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
model_config["configs"] = override_model_configs
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
@@ -1286,7 +1003,7 @@ class Conversation(Base):
|
||||
def in_debug_mode(self) -> bool:
|
||||
return self.override_model_configs is not None
|
||||
|
||||
def to_dict(self) -> ConversationDict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@@ -1566,7 +1283,7 @@ class Message(Base):
|
||||
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||
|
||||
@property
|
||||
def message_files(self) -> list[MessageFileInfo]:
|
||||
def message_files(self) -> list[dict[str, Any]]:
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
@@ -1621,13 +1338,10 @@ class Message(Base):
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result = cast(
|
||||
list[MessageFileInfo],
|
||||
[
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
],
|
||||
)
|
||||
result: list[dict[str, Any]] = [
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
|
||||
db.session.commit()
|
||||
return result
|
||||
@@ -1637,7 +1351,7 @@ class Message(Base):
|
||||
self._extra_contents = list(contents)
|
||||
|
||||
@property
|
||||
def extra_contents(self) -> list[ExtraContentDict]:
|
||||
def extra_contents(self) -> list[dict[str, Any]]:
|
||||
return getattr(self, "_extra_contents", [])
|
||||
|
||||
@property
|
||||
@@ -1653,7 +1367,7 @@ class Message(Base):
|
||||
|
||||
return None
|
||||
|
||||
def to_dict(self) -> MessageDict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@@ -1677,7 +1391,7 @@ class Message(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: MessageDict) -> Message:
|
||||
def from_dict(cls, data: dict[str, Any]) -> Message:
|
||||
return cls(
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
@@ -1737,7 +1451,7 @@ class MessageFeedback(TypeBase):
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"app_id": str(self.app_id),
|
||||
@@ -2000,8 +1714,8 @@ class AppMCPServer(TypeBase):
|
||||
return result
|
||||
|
||||
@property
|
||||
def parameters_dict(self) -> dict[str, str]:
|
||||
return cast(dict[str, str], json.loads(self.parameters))
|
||||
def parameters_dict(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.parameters))
|
||||
|
||||
|
||||
class Site(Base):
|
||||
@@ -2441,7 +2155,7 @@ class TraceAppConfig(TypeBase):
|
||||
def tracing_config_str(self) -> str:
|
||||
return json.dumps(self.tracing_config_dict)
|
||||
|
||||
def to_dict(self) -> TraceAppConfigDict:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
||||
@@ -35,7 +35,7 @@ dependencies = [
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
"markdown~=3.8.1",
|
||||
"markdown~=3.5.1",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
@@ -113,7 +113,7 @@ dev = [
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=38.2.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"basedpyright~=1.38.2",
|
||||
"basedpyright~=1.31.0",
|
||||
"ruff~=0.14.0",
|
||||
"pytest~=8.3.2",
|
||||
"pytest-benchmark~=4.0.0",
|
||||
@@ -167,12 +167,12 @@ dev = [
|
||||
"import-linter>=2.3",
|
||||
"types-redis>=4.6.0.20241004",
|
||||
"celery-types>=0.23.0",
|
||||
"mypy~=1.19.1",
|
||||
"mypy~=1.17.1",
|
||||
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
|
||||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.55.0",
|
||||
"pyrefly>=0.54.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
@@ -247,13 +247,3 @@ module = [
|
||||
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
|
||||
]
|
||||
ignore_errors = true
|
||||
|
||||
[tool.pyrefly]
|
||||
project-includes = ["."]
|
||||
project-excludes = [
|
||||
".venv",
|
||||
"migrations/",
|
||||
]
|
||||
python-platform = "linux"
|
||||
python-version = "3.11.0"
|
||||
infer-with-first-use = false
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
configs/middleware/cache/redis_pubsub_config.py
|
||||
controllers/console/app/annotation.py
|
||||
controllers/console/app/app.py
|
||||
controllers/console/app/app_import.py
|
||||
controllers/console/app/mcp_server.py
|
||||
controllers/console/app/site.py
|
||||
controllers/console/auth/email_register.py
|
||||
controllers/console/human_input_form.py
|
||||
controllers/console/init_validate.py
|
||||
controllers/console/ping.py
|
||||
controllers/console/setup.py
|
||||
controllers/console/version.py
|
||||
controllers/console/workspace/trigger_providers.py
|
||||
controllers/service_api/app/annotation.py
|
||||
controllers/web/workflow_events.py
|
||||
core/agent/fc_agent_runner.py
|
||||
core/app/apps/advanced_chat/app_generator.py
|
||||
core/app/apps/advanced_chat/app_runner.py
|
||||
core/app/apps/advanced_chat/generate_task_pipeline.py
|
||||
core/app/apps/agent_chat/app_generator.py
|
||||
core/app/apps/base_app_generate_response_converter.py
|
||||
core/app/apps/base_app_generator.py
|
||||
core/app/apps/chat/app_generator.py
|
||||
core/app/apps/common/workflow_response_converter.py
|
||||
core/app/apps/completion/app_generator.py
|
||||
core/app/apps/pipeline/pipeline_generator.py
|
||||
core/app/apps/pipeline/pipeline_runner.py
|
||||
core/app/apps/workflow/app_generator.py
|
||||
core/app/apps/workflow/app_runner.py
|
||||
core/app/apps/workflow/generate_task_pipeline.py
|
||||
core/app/apps/workflow_app_runner.py
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/datasource/datasource_manager.py
|
||||
core/external_data_tool/api/api.py
|
||||
core/llm_generator/llm_generator.py
|
||||
core/llm_generator/output_parser/structured_output.py
|
||||
core/mcp/mcp_client.py
|
||||
core/ops/aliyun_trace/data_exporter/traceclient.py
|
||||
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
|
||||
core/ops/mlflow_trace/mlflow_trace.py
|
||||
core/ops/ops_trace_manager.py
|
||||
core/ops/tencent_trace/client.py
|
||||
core/ops/tencent_trace/utils.py
|
||||
core/plugin/backwards_invocation/base.py
|
||||
core/plugin/backwards_invocation/model.py
|
||||
core/prompt/utils/extract_thread_messages.py
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
core/rag/datasource/vdb/lindorm/lindorm_vector.py
|
||||
core/rag/datasource/vdb/matrixone/matrixone_vector.py
|
||||
core/rag/datasource/vdb/milvus/milvus_vector.py
|
||||
core/rag/datasource/vdb/myscale/myscale_vector.py
|
||||
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
|
||||
core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
|
||||
core/rag/datasource/vdb/upstash/upstash_vector.py
|
||||
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
|
||||
core/rag/datasource/vdb/weaviate/weaviate_vector.py
|
||||
core/rag/extractor/csv_extractor.py
|
||||
core/rag/extractor/excel_extractor.py
|
||||
core/rag/extractor/firecrawl/firecrawl_app.py
|
||||
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
|
||||
core/rag/extractor/html_extractor.py
|
||||
core/rag/extractor/jina_reader_extractor.py
|
||||
core/rag/extractor/markdown_extractor.py
|
||||
core/rag/extractor/notion_extractor.py
|
||||
core/rag/extractor/pdf_extractor.py
|
||||
core/rag/extractor/text_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_doc_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_eml_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_epub_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_msg_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_xml_extractor.py
|
||||
core/rag/extractor/watercrawl/client.py
|
||||
core/rag/extractor/watercrawl/extractor.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
core/rag/extractor/word_extractor.py
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
core/rag/retrieval/router/multi_dataset_function_call_router.py
|
||||
core/rag/summary_index/summary_index.py
|
||||
core/repositories/sqlalchemy_workflow_execution_repository.py
|
||||
core/repositories/sqlalchemy_workflow_node_execution_repository.py
|
||||
core/tools/__base/tool.py
|
||||
core/tools/mcp_tool/provider.py
|
||||
core/tools/plugin_tool/provider.py
|
||||
core/tools/utils/message_transformer.py
|
||||
core/tools/utils/web_reader_tool.py
|
||||
core/tools/workflow_as_tool/provider.py
|
||||
core/trigger/debug/event_selectors.py
|
||||
core/trigger/entities/entities.py
|
||||
core/trigger/provider.py
|
||||
core/workflow/workflow_entry.py
|
||||
dify_graph/entities/workflow_execution.py
|
||||
dify_graph/file/file_manager.py
|
||||
dify_graph/graph_engine/error_handler.py
|
||||
dify_graph/graph_engine/layers/execution_limits.py
|
||||
dify_graph/nodes/agent/agent_node.py
|
||||
dify_graph/nodes/base/node.py
|
||||
dify_graph/nodes/code/code_node.py
|
||||
dify_graph/nodes/datasource/datasource_node.py
|
||||
dify_graph/nodes/document_extractor/node.py
|
||||
dify_graph/nodes/human_input/human_input_node.py
|
||||
dify_graph/nodes/if_else/if_else_node.py
|
||||
dify_graph/nodes/iteration/iteration_node.py
|
||||
dify_graph/nodes/knowledge_index/knowledge_index_node.py
|
||||
dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
|
||||
dify_graph/nodes/list_operator/node.py
|
||||
dify_graph/nodes/llm/node.py
|
||||
dify_graph/nodes/loop/loop_node.py
|
||||
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
|
||||
dify_graph/nodes/question_classifier/question_classifier_node.py
|
||||
dify_graph/nodes/start/start_node.py
|
||||
dify_graph/nodes/template_transform/template_transform_node.py
|
||||
dify_graph/nodes/tool/tool_node.py
|
||||
dify_graph/nodes/trigger_plugin/trigger_event_node.py
|
||||
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
|
||||
dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
dify_graph/variables/types.py
|
||||
extensions/ext_fastopenapi.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
factories/variable_factory.py
|
||||
libs/external_api.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
libs/helper.py
|
||||
libs/login.py
|
||||
libs/module_loading.py
|
||||
libs/oauth.py
|
||||
libs/oauth_data_source.py
|
||||
models/trigger.py
|
||||
models/workflow.py
|
||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
services/auth/firecrawl/firecrawl.py
|
||||
services/auth/jina.py
|
||||
services/auth/jina/jina.py
|
||||
services/auth/watercrawl/watercrawl.py
|
||||
services/conversation_service.py
|
||||
services/dataset_service.py
|
||||
services/document_indexing_proxy/document_indexing_task_proxy.py
|
||||
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
|
||||
services/external_knowledge_service.py
|
||||
services/plugin/plugin_migration.py
|
||||
services/recommend_app/buildin/buildin_retrieval.py
|
||||
services/recommend_app/database/database_retrieval.py
|
||||
services/recommend_app/remote/remote_retrieval.py
|
||||
services/summary_index_service.py
|
||||
services/tools/tools_transform_service.py
|
||||
services/trigger/trigger_provider_service.py
|
||||
services/trigger/trigger_subscription_builder_service.py
|
||||
services/trigger/webhook_service.py
|
||||
services/workflow_draft_variable_service.py
|
||||
services/workflow_event_snapshot_service.py
|
||||
services/workflow_service.py
|
||||
tasks/app_generate/workflow_execute_task.py
|
||||
tasks/regenerate_summary_index_task.py
|
||||
tasks/trigger_processing_tasks.py
|
||||
tasks/workflow_cfs_scheduler/cfs_scheduler.py
|
||||
tasks/workflow_execution_tasks.py
|
||||
8
api/pyrefly.toml
Normal file
8
api/pyrefly.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
project-includes = ["."]
|
||||
project-excludes = [
|
||||
".venv",
|
||||
"migrations/",
|
||||
]
|
||||
python-platform = "linux"
|
||||
python-version = "3.11.0"
|
||||
infer-with-first-use = false
|
||||
@@ -1,6 +1,5 @@
|
||||
[pytest]
|
||||
pythonpath = .
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib
|
||||
addopts = --cov=./api --cov-report=json
|
||||
env =
|
||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
||||
@@ -20,7 +19,7 @@ env =
|
||||
GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
|
||||
HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
|
||||
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
|
||||
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
|
||||
MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
|
||||
MOCK_SWITCH = true
|
||||
|
||||
@@ -74,16 +74,6 @@ from tasks.mail_reset_password_task import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_join_enterprise_default_workspace(account_id: str) -> None:
|
||||
"""Best-effort join to enterprise default workspace."""
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
|
||||
from services.enterprise.enterprise_service import try_join_default_workspace
|
||||
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
@@ -297,14 +287,13 @@ class AccountService:
|
||||
email=email, name=name, interface_language=interface_language, password=password
|
||||
)
|
||||
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
except Exception:
|
||||
# Enterprise-only side-effect should run independently from personal workspace creation.
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
raise
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
from services.enterprise.enterprise_service import try_join_default_workspace
|
||||
|
||||
try_join_default_workspace(str(account.id))
|
||||
|
||||
return account
|
||||
|
||||
@@ -1418,18 +1407,18 @@ class RegisterService:
|
||||
and create_workspace_required
|
||||
and FeatureService.get_system_features().license.workspaces.is_available()
|
||||
):
|
||||
try:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
except Exception:
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
raise
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
from services.enterprise.enterprise_service import try_join_default_workspace
|
||||
|
||||
try_join_default_workspace(str(account.id))
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
db.session.rollback()
|
||||
logger.exception("Register failed")
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -33,7 +32,7 @@ from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.model import AppModelConfig, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
@@ -524,7 +523,7 @@ class AppDslService:
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
).from_model_config_dict(model_config)
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
|
||||
@@ -38,13 +38,6 @@ if TYPE_CHECKING:
|
||||
class AppGenerateService:
|
||||
@staticmethod
|
||||
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
|
||||
"""
|
||||
Build a subscription callback that coordinates when the background task starts.
|
||||
|
||||
- streams transport: start immediately (events are durable; late subscribers can replay).
|
||||
- pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task
|
||||
still runs if the client never connects.
|
||||
"""
|
||||
started = False
|
||||
lock = threading.Lock()
|
||||
|
||||
@@ -61,18 +54,10 @@ class AppGenerateService:
|
||||
started = True
|
||||
return True
|
||||
|
||||
channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
|
||||
if channel_type == "streams":
|
||||
# With Redis Streams, we can safely start right away; consumers can read past events.
|
||||
_try_start()
|
||||
|
||||
# Keep return type Callable[[], None] consistent while allowing an extra (no-op) call.
|
||||
def _on_subscribe_streams() -> None:
|
||||
_try_start()
|
||||
|
||||
return _on_subscribe_streams
|
||||
|
||||
# Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback.
|
||||
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
|
||||
# The Celery task may publish the first event before the API side actually subscribes,
|
||||
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
|
||||
# but also use a short fallback timer so the task still runs if the client never consumes.
|
||||
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
|
||||
timer.daemon = True
|
||||
timer.start()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
@@ -187,7 +187,7 @@ class AppService:
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
@@ -388,7 +388,7 @@ class AppService:
|
||||
agent_config = app_model_config.agent_mode_dict
|
||||
|
||||
# get all tools
|
||||
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
|
||||
tools = agent_config.get("tools", [])
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from werkzeug.datastructures import FileStorage
|
||||
@@ -107,7 +106,7 @@ class AudioService:
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = cast(str | None, text_to_speech_dict.get("voice"))
|
||||
voice = text_to_speech_dict.get("voice")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
|
||||
@@ -130,7 +130,7 @@ class HumanInputService:
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._form_repository = form_repository or HumanInputFormSubmissionRepository()
|
||||
self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)
|
||||
|
||||
def get_form_by_token(self, form_token: str) -> Form | None:
|
||||
record = self._form_repository.get_by_token(form_token)
|
||||
|
||||
@@ -63,12 +63,7 @@ class RagPipelineTransformService:
|
||||
):
|
||||
node = self._deal_file_extensions(node)
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if dataset.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("Unauthorized")
|
||||
node = self._deal_knowledge_index(
|
||||
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
|
||||
)
|
||||
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
|
||||
new_nodes.append(node)
|
||||
if new_nodes:
|
||||
graph["nodes"] = new_nodes
|
||||
@@ -160,13 +155,14 @@ class RagPipelineTransformService:
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self,
|
||||
knowledge_configuration: KnowledgeConfiguration,
|
||||
dataset: Dataset,
|
||||
doc_form: str,
|
||||
indexing_technique: str | None,
|
||||
retrieval_model: RetrievalSetting | None,
|
||||
node: dict,
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
|
||||
if indexing_technique == "high_quality":
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user