Compare commits

..

5 Commits

Author SHA1 Message Date
CodingOnStar
ed9623647e Merge remote-tracking branch 'origin/main' into refactor/base-comp 2026-03-03 16:00:41 +08:00
CodingOnStar
aa5a22991b refactor(toast): streamline toast component structure and improve cleanup logic
- Adjusted class names for consistency in the Toast component.
- Refactored the toastHandler.clear function to improve cleanup logic by using a dedicated unmountAndRemove function.
- Ensured proper handling of the timer for toast notifications.
2026-03-02 11:54:51 +08:00
CodingOnStar
4928917878 Merge remote-tracking branch 'origin/main' into refactor/base-comp 2026-03-02 11:51:03 +08:00
CodingOnStar
b00afff61e fix(tests): correct import paths in chat and context block test files 2026-03-02 11:31:59 +08:00
CodingOnStar
691248f477 test: add unit tests for various components including Alert, AppUnavailable, Badge, ThemeSelector, ThemeSwitcher, ActionButton, and AgentLogModal 2026-03-02 11:11:08 +08:00
719 changed files with 14123 additions and 43838 deletions

View File

@@ -204,16 +204,6 @@ When assigned to test a directory/path, test **ALL content** within that path:
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
### `nuqs` Query State Testing (Required for URL State Hooks)
When a component or hook uses `useQueryState` / `useQueryStates`:
- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`)
- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`)
- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values)
- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable)
- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test
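The bijectivity requirement above is language-agnostic; a minimal Python sketch using stdlib percent-encoding (the `serialize`/`parse` names are illustrative, not the actual parser API) shows why `%2F` and `%25` are the classic round-trip traps:

```python
from urllib.parse import quote, unquote

def serialize(value: str) -> str:
    # Percent-encode everything, including "/" (safe="" disables the default exemption).
    return quote(value, safe="")

def parse(raw: str) -> str:
    return unquote(raw)

# Bijectivity check: parse(serialize(x)) == x, even for already-encoded-looking inputs.
for edge in ["a/b", "100%", "a b", "%2F", "%25", ""]:
    assert parse(serialize(edge)) == edge
```

Without `safe=""`, `"a/b"` would survive encoding unchanged and later be indistinguishable from a path separator, which is exactly the legacy-encoded-value edge case the checklist calls out.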
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)

View File

@@ -80,9 +80,6 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`)
- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`)
- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values)
### Queries

View File

@@ -125,31 +125,6 @@ describe('Component', () => {
})
```
### 2.1 `nuqs` Query State (Preferred: Testing Adapter)
For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly.
```typescript
import { renderHookWithNuqs } from '@/test/nuqs-testing'
it('should sync query to URL with push history', async () => {
const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), {
searchParams: '?page=1',
})
act(() => {
result.current.setQuery({ page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.options.history).toBe('push')
expect(update.searchParams.get('page')).toBe('2')
})
```
Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope.
### 3. Portal Components (with Shared State)
```typescript

View File

@@ -1,100 +1,43 @@
---
name: orpc-contract-first
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
---
# oRPC Contract-First Development
## Intent
## Project Structure
- Keep contract as single source of truth in `web/contract/*`.
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.
## Minimal Structure
```text
```
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
├── billing.ts
└── ...other domains
web/service/client.ts
├── base.ts # Base contract (inputStructure: 'detailed')
├── router.ts # Router composition & type exports
├── marketplace.ts # Marketplace contracts
└── console/ # Console contracts by domain
├── system.ts
└── billing.ts
```
## Core Workflow
## Workflow
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
- Use `base.route({...}).output(type<...>())` as baseline.
- Add `.input(type<...>())` only when request has `params/query/body`.
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
2. Register contract in `web/contract/router.ts`
- Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utils.
1. **Create contract** in `web/contract/console/{domain}.ts`
- Import `base` from `../base` and `type` from `@orpc/contract`
- Define route with `path`, `method`, `input`, `output`
```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
2. **Register in router** at `web/contract/router.ts`
- Import directly from domain file (no barrel files)
- Nest by API prefix: `billing: { invoices, bindPartnerStack }`
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
staleTime: 5 * 60 * 1000,
throwOnError: true,
select: invoice => invoice.url,
}))
```
3. **Create hooks** in `web/service/use-{domain}.ts`
- Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
- Use `consoleClient.{group}.{contract}()` for API calls
## Query Usage Decision Rule
1. Default: call site directly uses `*.queryOptions(...)`.
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration:
- Combine multiple queries/mutations.
- Share domain-level derived state or invalidation helpers.
```typescript
const invoicesBaseQueryOptions = () =>
consoleQuery.billing.invoices.queryOptions({ retry: false })
const invoiceQuery = useQuery({
...invoicesBaseQueryOptions(),
throwOnError: true,
})
```
## Mutation Usage Decision Rule
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
- `.key(...)`:
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`:
- Use for a specific query's full key (exact query identity / direct cache addressing).
- `.mutationKey(...)`:
- Use for a specific mutation's full key.
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
## Contract Rules
## Key Rules
- **Input structure**: Always use `{ params, query?, body? }` format
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
- **Path params**: Use `{paramName}` in path, match in `params` object
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
- **No barrel files**: Import directly from specific files
- **Types**: Import from `@/types/`, use `type<T>()` helper
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
## Type Export

View File

@@ -7,7 +7,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

View File

@@ -1,43 +1,25 @@
version: 2
multi-ecosystem-groups:
python:
schedule:
interval: "weekly" # or whatever schedule you want
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 2
patterns: ["*"]
schedule:
interval: "weekly"
groups:
python-dependencies:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 2
patterns: ["*"]
schedule:
interval: "weekly"
groups:
uv-dependencies:
patterns:
- "*"
- package-ecosystem: "npm"
directory: "/web"
schedule:
interval: "weekly"
open-pull-requests-limit: 2
groups:
lexical:
patterns:
- "lexical"
- "@lexical/*"
storybook:
patterns:
- "storybook"
- "@storybook/*"
npm-dependencies:
patterns:
- "*"
exclude-patterns:
- "lexical"
- "@lexical/*"
- "storybook"
- "@storybook/*"

View File

@@ -89,7 +89,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@@ -28,7 +28,7 @@ jobs:
- name: Use Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'

View File

@@ -57,7 +57,7 @@ jobs:
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@@ -39,7 +39,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -83,7 +83,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -457,7 +457,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

1
.gitignore vendored
View File

@@ -222,7 +222,6 @@ mise.toml
# AI Assistant
.roo/
/.claude/worktrees/
api/.env.backup
/clickzetta

View File

@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"--loglevel",
"INFO"
],

View File

@@ -29,7 +29,7 @@ The codebase is split into:
## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
## General Practices

View File

@@ -68,9 +68,8 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
@echo "📝 Running type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@./dev/pyrefly-check-local
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
@@ -132,7 +131,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
@echo " make type-check - Run type checks (basedpyright, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@@ -42,8 +42,6 @@ REFRESH_TOKEN_EXPIRE_DAYS=30
# redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
# Optional: limit total connections in connection pool (unset for default)
# REDIS_MAX_CONNECTIONS=200
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false

View File

@@ -28,8 +28,17 @@ ignore_imports =
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events
dify_graph.nodes.loop.loop_node -> dify_graph.graph_events
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine.command_channels
dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine
dify_graph.nodes.loop.loop_node -> dify_graph.graph
dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine.command_channels
# TODO(QuantumGhost): fix the import violation later
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
@@ -49,6 +58,8 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@@ -92,9 +103,13 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
dify_graph.nodes.agent.agent_node -> core.model_manager
dify_graph.nodes.agent.agent_node -> core.provider_manager
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
dify_graph.nodes.document_extractor.node -> core.helper.ssrf_proxy
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -138,7 +153,10 @@ ignore_imports =
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.human_input.human_input_node -> extensions.ext_database
dify_graph.nodes.human_input.human_input_node -> core.repositories.human_input_repository
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.agent.agent_node -> services
dify_graph.nodes.tool.tool_node -> services

View File

@@ -62,22 +62,6 @@ This is the default standard for backend code in this repo. Follow it for new co
- Code should usually include type annotations that match the repo's current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there's a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
```python
from datetime import datetime
from typing import NotRequired, TypedDict
class UserProfile(TypedDict):
user_id: str
email: str
created_at: datetime
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python

View File

@@ -30,7 +30,6 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.db_migration_lock import DbMigrationAutoRenewLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
@@ -937,12 +936,6 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
is_flag=True,
help="Preview cleanup results without deleting any workflow run data.",
)
@click.option(
"--task-label",
default="daily",
show_default=True,
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
)
def clean_workflow_runs(
before_days: int,
batch_size: int,
@@ -951,13 +944,10 @@ def clean_workflow_runs(
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
task_label: str,
):
"""
Clean workflow runs and related workflow data for free tenants.
"""
from extensions.otel.runtime import flush_telemetry
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
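The `^` (XOR) guard above is a compact way to enforce paired options: it fires exactly when one bound is supplied without the other. A minimal stdlib sketch of the same check (function name hypothetical):

```python
def validate_range(start_from, end_before):
    """Raise unless both bounds are provided together or both omitted."""
    # XOR of the two None-checks is True only when exactly one is missing.
    if (start_from is None) ^ (end_before is None):
        raise ValueError("--start-from and --end-before must be provided together.")
    return (start_from, end_before)
```

Both-omitted falls through to the default time window, both-provided selects the explicit range, and the half-specified case fails fast before any cleanup work starts.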
@@ -977,17 +967,13 @@ def clean_workflow_runs(
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
try:
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
task_label=task_label,
).run()
finally:
flush_telemetry()
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
@@ -2612,29 +2598,15 @@ def migrate_oss(
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=False,
default=None,
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=False,
default=None,
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--from-days-ago",
type=int,
default=None,
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
)
@click.option(
"--before-days",
type=int,
default=None,
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
@@ -2643,99 +2615,33 @@ def migrate_oss(
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
@click.option(
"--task-label",
default="daily",
show_default=True,
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
)
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
from_days_ago: int | None,
before_days: int | None,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
task_label: str,
):
"""
Clean expired messages and related data for tenants based on clean policy.
"""
from extensions.otel.runtime import flush_telemetry
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
abs_mode = start_from is not None and end_before is not None
rel_mode = before_days is not None
if abs_mode and rel_mode:
raise click.UsageError(
"Options are mutually exclusive: use either (--start-from,--end-before) "
"or (--from-days-ago,--before-days)."
)
if from_days_ago is not None and before_days is None:
raise click.UsageError("--from-days-ago must be used together with --before-days.")
if (start_from is None) ^ (end_before is None):
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
if not abs_mode and not rel_mode:
raise click.UsageError(
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
)
if rel_mode:
assert before_days is not None
if before_days < 0:
raise click.UsageError("--before-days must be >= 0.")
if from_days_ago is not None:
if from_days_ago < 0:
raise click.UsageError("--from-days-ago must be >= 0.")
if from_days_ago <= before_days:
raise click.UsageError("--from-days-ago must be greater than --before-days.")
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
if abs_mode:
assert start_from is not None
assert end_before is not None
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
elif from_days_ago is None:
assert before_days is not None
service = MessagesCleanService.from_days(
policy=policy,
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
else:
assert before_days is not None
assert from_days_ago is not None
now = naive_utc_now()
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=now - datetime.timedelta(days=from_days_ago),
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
@@ -2760,81 +2666,5 @@ def clean_expired_messages(
)
)
raise
finally:
flush_telemetry()
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--filename",
required=True,
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
app_id: str,
start_from: datetime.datetime | None,
end_before: datetime.datetime,
filename: str,
use_cloud_storage: bool,
batch_size: int,
dry_run: bool,
):
if start_from and start_from >= end_before:
raise click.UsageError("--start-from must be before --end-before.")
from services.retention.conversation.message_export_service import AppMessageExportService
try:
validated_filename = AppMessageExportService.validate_export_filename(filename)
except ValueError as e:
raise click.BadParameter(str(e), param_hint="--filename") from e
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
start_at = time.perf_counter()
try:
service = AppMessageExportService(
app_id=app_id,
end_before=end_before,
filename=validated_filename,
start_from=start_from,
batch_size=batch_size,
use_cloud_storage=use_cloud_storage,
dry_run=dry_run,
)
stats = service.run()
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"export_app_messages: completed in {elapsed:.2f}s\n"
f" - Batches: {stats.batches}\n"
f" - Total messages: {stats.total_messages}\n"
f" - Messages with feedback: {stats.messages_with_feedback}\n"
f" - Total feedbacks: {stats.total_feedbacks}",
fg="green",
)
)
except Exception as e:
elapsed = time.perf_counter() - start_at
logger.exception("export_app_messages failed")
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
raise
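The export command above streams messages into a `.jsonl.gz` file; as a minimal stdlib sketch of that output format (helper name hypothetical, not the service's actual API), one JSON object per line inside a gzip stream:

```python
import gzip
import json

def export_jsonl_gz(records, path):
    """Write each record as one JSON line into a gzip-compressed file; return the count."""
    count = 0
    # Text mode ("wt") lets json.dumps output be written as str directly.
    with gzip.open(path, "wt", encoding="utf-8") as fh:
        for rec in records:
            fh.write(json.dumps(rec, ensure_ascii=False) + "\n")
            count += 1
    return count
```

Readers can then stream the file back line by line with `gzip.open(path, "rt")` and `json.loads`, without loading the whole export into memory.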

View File

@@ -111,8 +111,3 @@ class RedisConfig(BaseSettings):
description="Enable client side cache in redis",
default=False,
)
REDIS_MAX_CONNECTIONS: PositiveInt | None = Field(
description="Maximum connections in the Redis connection pool (unset for library default)",
default=None,
)

View File

@@ -1,7 +1,7 @@
from typing import Literal, Protocol
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -23,56 +23,41 @@ class RedisConfigDefaultsMixin:
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
"""
Configuration settings for event transport between API and workers.
Supported transports:
- pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once)
- sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling)
- streams: Redis Streams (at-least-once, supports late subscribers)
Configuration settings for Redis pub/sub streaming.
"""
PUBSUB_REDIS_URL: str | None = Field(
validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"),
alias="PUBSUB_REDIS_URL",
description=(
"Redis connection URL for streaming events between API and celery worker; "
"defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL."
"Redis connection URL for pub/sub streaming events between API "
"and celery worker, defaults to url constructed from "
"`REDIS_*` configurations"
),
default=None,
)
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
description=(
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
"recommended to enable this for large deployments."
),
default=False,
)
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field(
validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"),
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
description=(
"Event transport type. Options are:\n\n"
" - pubsub: normal Pub/Sub (at-most-once)\n"
" - sharded: sharded Pub/Sub (at-most-once)\n"
" - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n"
"Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n"
"Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n"
"the risk of data loss from Redis auto-eviction under memory pressure.\n"
"Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE."
"Pub/sub channel type for streaming events. "
"Valid options are:\n"
"\n"
" - pubsub: for normal Pub/Sub\n"
" - sharded: for sharded Pub/Sub\n"
"\n"
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
"for large deployments."
),
default="pubsub",
)
PUBSUB_STREAMS_RETENTION_SECONDS: int = Field(
validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"),
description=(
"When using 'streams', expire each stream key this many seconds after the last event is published. "
"Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS."
),
default=600,
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:

View File

@@ -1,5 +1,3 @@
from typing import Any, cast
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@@ -25,14 +23,14 @@ class AppParameterApi(InstalledAppResource):
if workflow is None:
raise AppUnavailableError()
features_dict: dict[str, Any] = workflow.features_dict
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -185,4 +185,4 @@ class AnnotationUpdateDeleteApi(Resource):
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return "", 204
return {"result": "success"}, 204

View File

@@ -1,5 +1,3 @@
from typing import Any, cast
from flask_restx import Resource
from controllers.common.fields import Parameters
@@ -35,14 +33,14 @@ class AppParameterApi(Resource):
if workflow is None:
raise AppUnavailableError()
features_dict: dict[str, Any] = workflow.features_dict
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -14,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
ConversationDelete,
ConversationInfiniteScrollPagination,
SimpleConversation,
)
@@ -162,7 +163,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return "", 204
return ConversationDelete(result="success").model_dump(mode="json"), 204
@service_api_ns.route("/conversations/<uuid:c_id>/name")

View File

@@ -132,8 +132,6 @@ class WorkflowRunDetailApi(Resource):
app_id=app_model.id,
run_id=workflow_run_id,
)
if not workflow_run:
raise NotFound("Workflow run not found.")
return workflow_run

View File

@@ -1,5 +1,4 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource
@@ -58,14 +57,14 @@ class AppParameterApi(WebApiResource):
if workflow is None:
raise AppUnavailableError()
features_dict: dict[str, Any] = workflow.features_dict
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -1,13 +1,10 @@
from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
@@ -15,7 +12,7 @@ class SensitiveWordAvoidanceConfigManager:
if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config", {}),
config=sensitive_word_avoidance_dict.get("config"),
)
else:
return None

View File

@@ -1,13 +1,10 @@
from typing import Any, cast
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
from models.model import AppModelConfigDict
class AgentConfigManager:
@classmethod
def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
def convert(cls, config: dict) -> AgentEntity | None:
"""
Convert model config to model config
@@ -31,17 +28,17 @@ class AgentConfigManager:
agent_tools = []
for tool in agent_dict.get("tools", []):
tool_dict = cast(dict[str, Any], tool)
if len(tool_dict) >= 4:
if "enabled" not in tool_dict or not tool_dict["enabled"]:
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
"provider_type": tool_dict["provider_type"],
"provider_id": tool_dict["provider_id"],
"tool_name": tool_dict["tool_name"],
"tool_parameters": tool_dict.get("tool_parameters", {}),
"credential_id": tool_dict.get("credential_id", None),
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
@@ -50,8 +47,7 @@ class AgentConfigManager:
"react_router",
"router",
}:
agent_prompt_raw = agent_dict.get("prompt", None)
agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
@@ -79,7 +75,7 @@ class AgentConfigManager:
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
max_iteration=agent_dict.get("max_iteration", 10),
)
return None
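Several hunks in this branch add or drop `typing.cast` around these tool dicts. Worth noting that `cast` is purely a hint for the static type checker; at runtime it returns its argument unchanged, so adding or removing it cannot change behaviour. A quick demonstration:

```python
from typing import Any, cast

tool = {"provider_type": "builtin", "tool_name": "search", "enabled": True}

# cast() performs no conversion or validation at runtime;
# it hands the same object back untouched.
tool_dict = cast(dict[str, Any], tool)
print(tool_dict is tool)  # True
```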

View File

@@ -1,5 +1,5 @@
import uuid
from typing import Any, Literal, cast
from typing import Literal, cast
from core.app.app_config.entities import (
DatasetEntity,
@@ -8,13 +8,13 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode, AppModelConfigDict
from models.model import AppMode
from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
def convert(cls, config: dict) -> DatasetEntity | None:
"""
Convert model config to model config
@@ -25,15 +25,11 @@ class DatasetConfigManager:
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
for dataset in datasets.get("datasets", []):
if not isinstance(dataset, dict):
continue
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != "dataset":
continue
dataset = dataset["dataset"]
if not isinstance(dataset, dict):
continue
if "enabled" not in dataset or not dataset["enabled"]:
continue
@@ -51,14 +47,15 @@ class DatasetConfigManager:
agent_dict = config.get("agent_mode", {})
for tool in agent_dict.get("tools", []):
if len(tool) == 1:
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != "dataset":
continue
tool_item = cast(dict[str, Any], tool)[key]
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue

View File

@@ -5,13 +5,12 @@ from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
class ModelConfigManager:
@classmethod
def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
def convert(cls, config: dict) -> ModelConfigEntity:
"""
Convert model config to model config
@@ -23,7 +22,7 @@ class ModelConfigManager:
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get("completion_params") or {}
completion_params = model_config.get("completion_params")
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
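The `or {}` guard on `completion_params` matters here: a stored null value would make the following `"stop" in completion_params` membership test raise. A quick demonstration of the difference:

```python
model_config = {"name": "gpt-4", "completion_params": None}

# Guarded: a null value collapses to an empty dict, so `in` is safe.
params = model_config.get("completion_params") or {}
print("stop" in params)  # False

# Unguarded: membership testing against None raises TypeError.
try:
    "stop" in model_config.get("completion_params")
except TypeError:
    print("TypeError without the `or {}` guard")
```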

View File

@@ -1,5 +1,3 @@
from typing import Any
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
@@ -8,12 +6,12 @@ from core.app.app_config.entities import (
)
from core.prompt.simple_prompt_transform import ModelMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode, AppModelConfigDict
from models.model import AppMode
class PromptTemplateConfigManager:
@classmethod
def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
def convert(cls, config: dict) -> PromptTemplateEntity:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
@@ -42,15 +40,14 @@ class PromptTemplateConfigManager:
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params: dict[str, Any] = {
completion_prompt_template_params = {
"prompt": completion_prompt_config["prompt"]["text"],
}
conv_role = completion_prompt_config.get("conversation_histories_role")
if conv_role:
if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params["role_prefix"] = {
"user": conv_role["user_prefix"],
"assistant": conv_role["assistant_prefix"],
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(

View File

@@ -1,10 +1,8 @@
import re
from typing import cast
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
from models.model import AppModelConfigDict
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[
@@ -20,7 +18,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
class BasicVariablesConfigManager:
@classmethod
def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
"""
Convert model config to model config
@@ -53,9 +51,7 @@ class BasicVariablesConfigManager:
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable["variable"],
type=variable.get("type", ""),
config=variable.get("config", {}),
variable=variable["variable"], type=variable["type"], config=variable["config"]
)
)
elif variable_type in {
@@ -68,10 +64,10 @@ class BasicVariablesConfigManager:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=cast(VariableEntityType, variable_type),
variable=variable["variable"],
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable["label"],
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],

View File

@@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
app_model_config_dict: dict[str, Any]
app_model_config_dict: dict
model: ModelConfigEntity
prompt_template: PromptTemplateEntity
dataset: DatasetEntity | None = None

View File

@@ -516,10 +516,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state,
)
yield from self._handle_advanced_chat_message_end_event(
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_partial_success_event(
self,
@@ -540,9 +538,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
exceptions_count=event.exceptions_count,
)
yield from self._handle_advanced_chat_message_end_event(
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
@@ -740,6 +735,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
@@ -859,14 +855,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
yield from self._handle_workflow_paused_event(event)
break
case QueueWorkflowSucceededEvent():
yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPartialSuccessEvent():
yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break

View File

@@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
from models.model import App, AppMode, AppModelConfig, Conversation
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
@@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: AppModelConfigDict | None = None,
override_config_dict: dict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
@@ -61,9 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(
@@ -72,7 +70,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=cast(dict[str, Any], config_dict),
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -88,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
"""
Validate for agent chat app model config
@@ -159,7 +157,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return cast(AppModelConfigDict, filtered_config)
return filtered_config
@classmethod
def validate_agent_mode_and_set_defaults(

View File

@@ -1,5 +1,3 @@
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -15,7 +13,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
from models.model import App, AppMode, AppModelConfig, Conversation
class ChatAppConfig(EasyUIBasedAppConfig):
@@ -33,7 +31,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: AppModelConfigDict | None = None,
override_config_dict: dict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
@@ -66,7 +64,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=cast(dict[str, Any], config_dict),
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -81,7 +79,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for chat app model config
@@ -147,4 +145,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return cast(AppModelConfigDict, filtered_config)
return filtered_config

View File

@@ -173,10 +173,8 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=bool(
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
.get("image", {})
.get("enabled", False)
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []

View File

@@ -1,5 +1,3 @@
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -10,7 +8,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict
from models.model import App, AppMode, AppModelConfig
class CompletionAppConfig(EasyUIBasedAppConfig):
@@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
@@ -42,9 +40,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(
@@ -53,7 +49,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=cast(dict[str, Any], config_dict),
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -68,7 +64,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for completion app model config
@@ -120,4 +116,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return cast(AppModelConfigDict, filtered_config)
return filtered_config

View File

@@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params", {})
completion_params = model_dict.get("completion_params")
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict

View File

@@ -132,10 +132,8 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=bool(
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
.get("image", {})
.get("enabled", False)
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
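The `bool(...)` wrapper above also inserts the intermediate `.get("image", {})` level; without it, the flat lookup silently reads `enabled` at the wrong depth. A quick comparison of the two lookups:

```python
config = {"file_upload": {"image": {"enabled": True}}}

# Chaining .get with {} defaults tolerates any missing intermediate level.
nested = bool(config.get("file_upload", {}).get("image", {}).get("enabled", False))

# The flat lookup skips the nested "image" section entirely.
flat = bool(config.get("file_upload", {}).get("enabled", False))

print(nested)  # True
print(flat)    # False
```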

View File

@@ -8,14 +8,12 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
UserFrom,
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.graph_init_params import GraphInitParams
from dify_graph.enums import WorkflowType
from dify_graph.enums import UserFrom, WorkflowType
from dify_graph.graph import Graph
from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent
from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
@@ -258,15 +256,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
),
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=0,
)

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -33,6 +33,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import UserFrom
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import (
@@ -118,15 +119,13 @@ class WorkflowBasedAppRunner:
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=tenant_id or "",
app_id=self._app_id,
workflow_id=workflow_id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=tenant_id or "",
app_id=self._app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
),
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=0,
)
@@ -268,15 +267,13 @@ class WorkflowBasedAppRunner:
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)

View File

@@ -1,5 +1,4 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@@ -7,7 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import InvokeFrom
from dify_graph.file import File, FileUploadConfig
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
@@ -15,69 +14,6 @@ if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"
class InvokeFrom(StrEnum):
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class DifyRunContext(BaseModel):
tenant_id: str
app_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
def build_dify_run_context(
*,
tenant_id: str,
app_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
extra_context: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
"""
Build graph run_context with the reserved Dify runtime payload.
`extra_context` can carry user-defined context keys. The reserved `_dify`
payload is always overwritten by this function to keep one canonical source.
"""
run_context = dict(extra_context) if extra_context else {}
run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext(
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
)
return run_context
class ModelConfigWithCredentialsEntity(BaseModel):
"""
Model Config With Credentials Entity.

View File

@@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Any, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -44,13 +44,14 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_file_utils import prepare_file_dict
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -218,14 +219,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
@@ -459,40 +460,91 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
# Fetch files associated with this message
files = None
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if message_files:
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
upload_file_ids = list(
dict.fromkeys(
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
)
)
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
files_list = []
for message_file in message_files:
file_dict = prepare_file_dict(message_file, upload_files_map)
files_list.append(file_dict)
files = files_list or None
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
files=files,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

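The tool-file branch above strips query parameters from the last URL segment and then uses `rsplit(".", 1)` so that filenames with multiple dots keep everything before the final dot as the file id. A standalone sketch of that parsing (the helper name is illustrative, not from the codebase):

```python
def split_tool_file_part(url: str) -> tuple[str, str]:
    """Split the last URL segment into (tool_file_id, extension).

    Mirrors the logic above: query params are dropped first, and only
    the final dot separates the extension, so multi-dot names survive.
    """
    file_part = url.split("/")[-1].split("?")[0]  # remove query params
    if "." in file_part:
        tool_file_id, ext = file_part.rsplit(".", 1)
        return tool_file_id, f".{ext}"
    return file_part, ".bin"  # no extension: fall back to .bin


print(split_tool_file_part("https://host/files/abc.v2.tar.gz?sig=x"))
```

A plain `split(".", 1)` would instead cut at the first dot and mangle ids like `abc.v2.tar`.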
View File

@@ -1,6 +1,7 @@
import hashlib
import logging
from threading import Thread, Timer
import time
from threading import Thread
from typing import Union
from flask import Flask, current_app
@@ -95,9 +96,9 @@ class MessageCycleManager:
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep does not block other logic
thread = Timer(
1,
self._generate_conversation_name_worker,
time.sleep(1)
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation_id,

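The hunk above swaps a blocking `time.sleep(1)` followed by a `Thread` for a `threading.Timer`, which schedules the worker after the delay without ever blocking the caller. A minimal sketch of the difference (the worker body is illustrative):

```python
import threading
import time

results = []


def worker(conversation_id: str) -> None:
    results.append(conversation_id)


# Timer schedules the call to run after 0.1s and returns immediately,
# so the calling thread is never parked in a sleep.
timer = threading.Timer(0.1, worker, kwargs={"conversation_id": "c1"})
start = time.perf_counter()
timer.start()
elapsed = time.perf_counter() - start  # well under the 0.1s delay

timer.join()  # wait for the worker in this demo only
print(results)
```

With the old `time.sleep(1)` variant, the one-second pause ran on the request thread before the worker thread was even started.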
View File

@@ -1,76 +0,0 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
"""
Prepare file dictionary for message end stream response.
:param message_file: MessageFile instance
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
:return: Dictionary containing file information
"""
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
return {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}

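The removed helper above also caps the extracted extension at `MAX_TOOL_FILE_EXTENSION_LENGTH` (10), treating anything longer as a dot inside an opaque id rather than a real suffix. A sketch of just that guard (the function name is illustrative):

```python
MAX_TOOL_FILE_EXTENSION_LENGTH = 10  # same cap as the helper above


def safe_extension(filename: str) -> str:
    """Extract a dotted extension, falling back to '.bin' when it is
    missing or implausibly long (e.g. a dot inside an opaque id)."""
    if "." not in filename:
        return ".bin"
    ext = "." + filename.rsplit(".", 1)[1]
    return ext if len(ext) <= MAX_TOOL_FILE_EXTENSION_LENGTH else ".bin"


print(safe_extension("report.pdf"))
print(safe_extension("0123456789abcdef.deadbeefcafebabe"))
```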
View File

@@ -75,9 +75,8 @@ class LLMQuotaLayer(GraphEngineLayer):
return
try:
dify_ctx = node.require_dify_context()
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)

View File

@@ -7,7 +7,7 @@ import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Final
from typing import Final, cast
from urllib.parse import urljoin
import httpx
@@ -201,7 +201,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
return cast(int, uuid_obj.int)
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e

View File

@@ -120,8 +120,7 @@ class TencentTraceClient:
# Metrics exporter and instruments
try:
from opentelemetry.sdk.metrics import Histogram as SdkHistogram
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
@@ -129,7 +128,7 @@ class TencentTraceClient:
use_http_json = protocol in {"http/json", "http-json"}
# Tencent APM works best with delta aggregation temporality
preferred_temporality: dict[type, AggregationTemporality] = {SdkHistogram: AggregationTemporality.DELTA}
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
def _create_metric_exporter(exporter_cls, **kwargs):
"""Create metric exporter with preferred_temporality support"""

View File

@@ -6,6 +6,7 @@ import hashlib
import random
import uuid
from datetime import datetime
from typing import cast
from opentelemetry.trace import Link, SpanContext, TraceFlags
@@ -22,7 +23,7 @@ class TencentTraceUtils:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
return uuid_obj.int
return cast(int, uuid_obj.int)
@staticmethod
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
@@ -51,9 +52,9 @@ class TencentTraceUtils:
@staticmethod
def create_link(trace_id_str: str) -> Link:
try:
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
except (ValueError, TypeError):
trace_id = uuid.uuid4().int
trace_id = cast(int, uuid.uuid4().int)
span_context = SpanContext(
trace_id=trace_id,

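`create_link` above accepts a trace id either as 32 hex characters or as a UUID string, and falls back to a random trace id when neither parses. The parsing step in isolation (helper name is illustrative):

```python
import uuid


def parse_trace_id(trace_id_str: str) -> int:
    """Accept a 32-char hex trace id or a UUID string; fall back to a
    random trace id when neither form parses."""
    try:
        if len(trace_id_str) == 32:
            return int(trace_id_str, 16)
        return uuid.UUID(trace_id_str).int
    except (ValueError, TypeError):
        return uuid.uuid4().int


print(parse_trace_id("0" * 31 + "f"))  # 32 hex chars
```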
View File

@@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if workflow is None:
raise ValueError("unexpected app type")
features_dict: dict[str, Any] = workflow.features_dict
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = cast(dict[str, Any], app_model_config.to_dict())
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -65,7 +65,7 @@ class ChromaVector(BaseVector):
self._client.get_or_create_collection(collection_name)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
@@ -73,7 +73,6 @@ class ChromaVector(BaseVector):
collection = self._client.get_or_create_collection(self._collection_name)
# FIXME: chromadb uses numpy arrays; fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
return uuids
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)

View File

@@ -605,36 +605,25 @@ class ClickzettaVector(BaseVector):
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add documents with embeddings to the collection."""
if not documents:
return []
return
batch_size = self._config.batch_size
total_batches = (len(documents) + batch_size - 1) // batch_size
added_ids = []
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
batch_doc_ids = []
for doc in batch_docs:
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))))
added_ids.extend(batch_doc_ids)
# Execute batch insert through write queue
self._execute_write(
self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches
)
return added_ids
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
def _insert_batch(
self,
batch_docs: list[Document],
batch_embeddings: list[list[float]],
batch_doc_ids: list[str],
batch_index: int,
batch_size: int,
total_batches: int,
@@ -652,9 +641,14 @@ class ClickzettaVector(BaseVector):
data_rows = []
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids):
for doc, embedding in zip(batch_docs, batch_embeddings):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
metadata = doc.metadata or {}
if not isinstance(metadata, dict):
metadata = {}
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
# Fast path for JSON serialization
try:

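`add_texts` above walks the documents in fixed-size batches, computing `total_batches` with the integer ceil-division idiom `(n + size - 1) // size`. A quick sketch of that slicing (function name is illustrative):

```python
def batch_slices(n_items: int, batch_size: int) -> list[tuple[int, int]]:
    """Return (start, end) index pairs covering n_items in batches."""
    total_batches = (n_items + batch_size - 1) // batch_size  # ceil division
    slices = [(i, min(i + batch_size, n_items)) for i in range(0, n_items, batch_size)]
    assert len(slices) == total_batches
    return slices


print(batch_slices(5, 2))
```

The final batch is allowed to be short, which is exactly what the ceil division accounts for.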
View File

@@ -4,10 +4,9 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.db.session_factory import session_factory
from dify_graph.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryMethod,
@@ -199,9 +198,12 @@ class _InvalidTimeoutStatusError(ValueError):
class HumanInputFormRepositoryImpl:
def __init__(
self,
*,
session_factory: sessionmaker | Engine,
tenant_id: str,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._tenant_id = tenant_id
def _delivery_method_to_model(
@@ -215,7 +217,7 @@ class HumanInputFormRepositoryImpl:
id=delivery_id,
form_id=form_id,
delivery_method_type=delivery_method.type,
delivery_config_id=str(delivery_method.id),
delivery_config_id=delivery_method.id,
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
@@ -341,7 +343,7 @@ class HumanInputFormRepositoryImpl:
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
with session_factory.create_session() as session, session.begin():
with self._session_factory(expire_on_commit=False) as session, session.begin():
# Generate unique form ID
form_id = str(uuidv7())
start_time = naive_utc_now()
@@ -433,7 +435,7 @@ class HumanInputFormRepositoryImpl:
HumanInputForm.node_id == node_id,
HumanInputForm.tenant_id == self._tenant_id,
)
with session_factory.create_session() as session:
with self._session_factory(expire_on_commit=False) as session:
form_model: HumanInputForm | None = session.scalars(form_query).first()
if form_model is None:
return None
@@ -446,13 +448,18 @@ class HumanInputFormRepositoryImpl:
class HumanInputFormSubmissionRepository:
"""Repository for fetching and submitting human input forms."""
def __init__(self, session_factory: sessionmaker | Engine):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with session_factory.create_session() as session:
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
@@ -471,7 +478,7 @@ class HumanInputFormSubmissionRepository:
HumanInputFormRecipient.recipient_type == recipient_type,
)
)
with session_factory.create_session() as session:
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
@@ -487,7 +494,7 @@ class HumanInputFormSubmissionRepository:
submission_user_id: str | None,
submission_end_user_id: str | None,
) -> HumanInputFormRecord:
with session_factory.create_session() as session, session.begin():
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
@@ -517,7 +524,7 @@ class HumanInputFormSubmissionRepository:
timeout_status: HumanInputFormStatus,
reason: str | None = None,
) -> HumanInputFormRecord:
with session_factory.create_session() as session, session.begin():
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")

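Both repositories above accept either a `sessionmaker` or a bare `Engine` and normalize in `__init__` by wrapping the engine. The same accept-either-and-normalize pattern with SQLAlchemy stripped out (these class names are stand-ins, not the real API):

```python
class Engine:
    """Stand-in for sqlalchemy.Engine."""


class SessionFactory:
    """Stand-in for sqlalchemy.orm.sessionmaker."""

    def __init__(self, bind: Engine) -> None:
        self.bind = bind


class Repository:
    def __init__(self, session_factory: "SessionFactory | Engine") -> None:
        # Normalize: callers may pass a bare engine for convenience,
        # but internally we always hold a session factory.
        if isinstance(session_factory, Engine):
            session_factory = SessionFactory(bind=session_factory)
        self._session_factory = session_factory


engine = Engine()
repo = Repository(engine)
print(type(repo._session_factory).__name__)
```

Normalizing once at construction keeps every session-opening call site (`self._session_factory(...)`) uniform.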
View File

@@ -194,13 +194,6 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Create a new database session
with self._session_factory() as session:
existing_model = session.get(WorkflowRun, db_model.id)
if existing_model:
if existing_model.tenant_id != self._tenant_id:
raise ValueError("Unauthorized access to workflow run")
# Preserve the original start time for pause/resume flows.
db_model.created_at = existing_model.created_at
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)

View File

@@ -37,7 +37,6 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
}

View File

@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
@@ -20,10 +19,8 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
@@ -37,7 +34,6 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.datasource import DatasourceNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from dify_graph.nodes.llm.entities import ModelConfig
@@ -112,7 +108,6 @@ class DifyNodeFactory(NodeFactory):
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._dify_context = self._resolve_dify_context(graph_init_params.run_context)
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
self._code_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
@@ -144,16 +139,7 @@ class DifyNodeFactory(NodeFactory):
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
@staticmethod
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, DifyRunContext):
return raw_ctx
return DifyRunContext.model_validate(raw_ctx)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -219,15 +205,6 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.HUMAN_INPUT:
return HumanInputNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,
@@ -277,7 +254,6 @@ class DifyNodeFactory(NodeFactory):
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
http_client=self._http_request_http_client,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
@@ -368,7 +344,7 @@ class DifyNodeFactory(NodeFactory):
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self._dify_context.app_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
)

View File

@@ -5,26 +5,26 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
from dify_graph.enums import UserFrom
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file.models import File
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from dify_graph.nodes import NodeType
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from extensions.otel.runtime import is_instrument_flag_enabled
@@ -34,66 +34,6 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
class _WorkflowChildEngineBuilder:
@staticmethod
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
"""
Return whether `graph_config["nodes"]` contains the given node id.
Returns `None` when the nodes payload shape is unexpected, so graph-level
validation can surface the original configuration error.
"""
nodes = graph_config.get("nodes")
if not isinstance(nodes, list):
return None
for node in nodes:
if not isinstance(node, Mapping):
return None
current_id = node.get("id")
if isinstance(current_id, str) and current_id == node_id:
return True
return False
def build_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: Mapping[str, Any],
root_node_id: str,
layers: Sequence[object] = (),
) -> GraphEngine:
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id)
if has_root_node is False:
raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found")
child_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=root_node_id,
)
child_engine = GraphEngine(
workflow_id=workflow_id,
graph=child_graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
child_engine_builder=self,
)
child_engine.layer(LLMQuotaLayer())
for layer in layers:
child_engine.layer(cast(GraphEngineLayer, layer))
return child_engine
class WorkflowEntry:
def __init__(
self,
@@ -137,7 +77,6 @@ class WorkflowEntry:
command_channel = InMemoryChannel()
self.command_channel = command_channel
self._child_engine_builder = _WorkflowChildEngineBuilder()
self.graph_engine = GraphEngine(
workflow_id=workflow_id,
graph=graph,
@@ -149,7 +88,6 @@ class WorkflowEntry:
scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD,
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
),
child_engine_builder=self._child_engine_builder,
)
# Add debug logging layer when in debug mode
@@ -216,15 +154,13 @@ class WorkflowEntry:
# init graph init params and runtime state
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@@ -357,15 +293,13 @@ class WorkflowEntry:
# init graph init params and runtime state
graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id="",
workflow_id="",
graph_config=graph_dict,
run_context=build_dify_run_context(
tenant_id=tenant_id,
app_id="",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

View File

@@ -3,7 +3,7 @@ from typing import Any
from pydantic import BaseModel, Field
DIFY_RUN_CONTEXT_KEY = "_dify"
from dify_graph.enums import InvokeFrom, UserFrom
class GraphInitParams(BaseModel):
@@ -18,7 +18,11 @@ class GraphInitParams(BaseModel):
"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
run_context: Mapping[str, Any] = Field(..., description="runtime context")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")

View File

@@ -33,6 +33,39 @@ class SystemVariableKey(StrEnum):
INVOKE_FROM = "invoke_from"
class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"
class InvokeFrom(StrEnum):
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
"""Get source of invoke from.
:return: source
"""
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class NodeType(StrEnum):
START = "start"
END = "end"

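`InvokeFrom.to_source` above maps enum members to analytics source strings with `"dev"` as the default for unmapped members. A trimmed, self-contained sketch (using `str` + `Enum` for portability instead of `StrEnum`, and only a subset of the members):

```python
from enum import Enum


class InvokeFrom(str, Enum):
    SERVICE_API = "service-api"
    WEB_APP = "web-app"
    DEBUGGER = "debugger"

    def to_source(self) -> str:
        # Members missing from the mapping fall back to "dev".
        mapping = {
            InvokeFrom.SERVICE_API: "api",
            InvokeFrom.WEB_APP: "web_app",
        }
        return mapping.get(self, "dev")


print(InvokeFrom("web-app").to_source())
```

Because the members are strings, `InvokeFrom("web-app")` round-trips the wire value straight back to a member before mapping.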
View File

@@ -9,7 +9,7 @@ from __future__ import annotations
import logging
import queue
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from dify_graph.context import capture_current_context
@@ -27,7 +27,6 @@ from dify_graph.graph_events import (
GraphRunSucceededEvent,
)
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from dify_graph.runtime.graph_runtime_state import GraphProtocol
@@ -50,7 +49,6 @@ from .protocols.command_channel import CommandChannel
from .worker_management import WorkerPool
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator
@@ -76,7 +74,6 @@ class GraphEngine:
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig = _DEFAULT_CONFIG,
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
@@ -86,9 +83,6 @@ class GraphEngine:
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
self._child_engine_builder = child_engine_builder
if child_engine_builder is not None:
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
@@ -220,25 +214,6 @@ class GraphEngine:
self._bind_layer_context(layer)
return self
def create_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: dict[str, object] | Mapping[str, object],
root_node_id: str,
layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (),
) -> GraphEngine:
return self._graph_runtime_state.create_child_engine(
workflow_id=workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
graph_config=graph_config,
root_node_id=root_node_id,
layers=layers,
)
def run(self) -> Generator[GraphEngineEvent, None, None]:
"""
Execute the graph using the modular architecture.

View File

@@ -80,11 +80,9 @@ class AgentNode(Node[AgentNodeData]):
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = get_plugin_agent_strategy(
tenant_id=dify_ctx.tenant_id,
tenant_id=self.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
@@ -122,8 +120,8 @@ class AgentNode(Node[AgentNodeData]):
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
@@ -146,8 +144,8 @@ class AgentNode(Node[AgentNodeData]):
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
@@ -285,13 +283,8 @@ class AgentNode(Node[AgentNodeData]):
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
@@ -403,8 +396,7 @@ class AgentNode(Node[AgentNodeData]):
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
plugin
@@ -425,11 +417,8 @@ class AgentNode(Node[AgentNodeData]):
return None
conversation_id = conversation_id_variable.value
dify_ctx = self.require_dify_context()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
@@ -440,10 +429,9 @@ class AgentNode(Node[AgentNodeData]):
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
@@ -452,7 +440,7 @@ class AgentNode(Node[AgentNodeData]):
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=dify_ctx.tenant_id,
tenant_id=self.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,

View File

@@ -8,11 +8,10 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
NodeExecutionType,
@@ -65,28 +64,10 @@ from libs.datetime_utils import naive_utc_now
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
logger = logging.getLogger(__name__)
class DifyRunContextProtocol(Protocol):
tenant_id: str
app_id: str
user_id: str
user_from: Any
invoke_from: Any
class _MappingDifyRunContext:
def __init__(self, mapping: Mapping[str, Any]) -> None:
self.tenant_id = str(mapping["tenant_id"])
self.app_id = str(mapping["app_id"])
self.user_id = str(mapping["user_id"])
self.user_from = mapping["user_from"]
self.invoke_from = mapping["invoke_from"]
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
@@ -246,10 +227,14 @@ class Node(Generic[NodeDataT]):
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self._run_context = MappingProxyType(dict(graph_init_params.run_context))
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.workflow_call_depth = graph_init_params.call_depth
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
@@ -278,38 +263,6 @@ class Node(Generic[NodeDataT]):
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
def run_context(self) -> Mapping[str, Any]:
return self._run_context
def get_run_context_value(self, key: str, default: Any = None) -> Any:
return self._run_context.get(key, default)
def require_run_context_value(self, key: str) -> Any:
value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE)
if value is _MISSING_RUN_CONTEXT_VALUE:
raise ValueError(f"run_context missing required key: {key}")
return value
def require_dify_context(self) -> DifyRunContextProtocol:
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, Mapping):
missing_keys = [
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
]
if missing_keys:
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
return _MappingDifyRunContext(raw_ctx)
for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
if not hasattr(raw_ctx, attr):
raise TypeError(f"invalid dify context object, missing attribute: {attr}")
return cast(DifyRunContextProtocol, raw_ctx)
@property
def execution_id(self) -> str:
return self._node_execution_id
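The hunk above removes the mapping-or-object run-context accessor from the base node. As a standalone sketch of that validation pattern (illustrative names only, not the Dify API), it amounts to:

```python
from collections.abc import Mapping
from typing import Any, Protocol, cast


class RunContextProtocol(Protocol):
    tenant_id: str
    user_id: str


class _MappingRunContext:
    """Adapter exposing required mapping keys as attributes."""

    def __init__(self, mapping: Mapping[str, Any]) -> None:
        self.tenant_id = str(mapping["tenant_id"])
        self.user_id = str(mapping["user_id"])


def require_context(raw: Any) -> RunContextProtocol:
    """Accept either a mapping or an object, validating required fields."""
    required = ("tenant_id", "user_id")
    if isinstance(raw, Mapping):
        missing = [key for key in required if key not in raw]
        if missing:
            raise ValueError(f"context missing required keys: {', '.join(missing)}")
        return _MappingRunContext(raw)
    for attr in required:
        if not hasattr(raw, attr):
            raise TypeError(f"invalid context object, missing attribute: {attr}")
    return cast(RunContextProtocol, raw)


ctx = require_context({"tenant_id": "t-1", "user_id": "u-1"})
print(ctx.tenant_id, ctx.user_id)  # t-1 u-1
```

The diff replaces this indirection with plain attributes copied from `GraphInitParams` at construction time, trading runtime validation for simpler access.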

@@ -52,7 +52,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
Run the datasource node
"""
dify_ctx = self.require_dify_context()
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
@@ -76,7 +75,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=dify_ctx.tenant_id,
tenant_id=self.tenant_id,
datasource_type=datasource_type.value,
)
@@ -105,11 +104,11 @@ class DatasourceNode(Node[DatasourceNodeData]):
yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=dify_ctx.user_id,
user_id=self.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=dify_ctx.tenant_id,
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
@@ -137,7 +136,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
raise DatasourceNodeError("File does not exist")
file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=dify_ctx.tenant_id
file_id=related_id, tenant_id=self.tenant_id
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())

@@ -4,7 +4,6 @@ import json
import logging
import os
import tempfile
import zipfile
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
@@ -21,11 +20,11 @@ from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from core.helper import ssrf_proxy
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod, file_manager
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.variables import ArrayFileSegment
from dify_graph.variables.segments import ArrayStringSegment, FileSegment
@@ -59,7 +58,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
unstructured_api_config: UnstructuredApiConfig | None = None,
http_client: HttpClientProtocol,
) -> None:
super().__init__(
id=id,
@@ -68,7 +66,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
graph_runtime_state=graph_runtime_state,
)
self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
self._http_client = http_client
def _run(self):
variable_selector = self.node_data.variable_selector
@@ -83,24 +80,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
value = variable.value
inputs = {"variable_selector": variable_selector}
if isinstance(value, list):
value = list(filter(lambda x: x, value))
process_data = {"documents": value if isinstance(value, list) else [value]}
if not value:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": ArrayStringSegment(value=[])},
)
try:
if isinstance(value, list):
extracted_text_list = [
_extract_text_from_file(
self._http_client, file, unstructured_api_config=self._unstructured_api_config
)
_extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
for file in value
]
return NodeRunResult(
@@ -110,9 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(
self._http_client, value, unstructured_api_config=self._unstructured_api_config
)
extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -122,7 +105,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
else:
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
except DocumentExtractorError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
@@ -397,32 +379,6 @@ def parser_docx_part(block, doc: Document, content_items, i):
content_items.append((i, "table", Table(block, doc)))
def _normalize_docx_zip(file_content: bytes) -> bytes:
"""
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
ZIP entry names use backslash (\\) as path separator instead of the forward
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
"word\\document.xml" is never found when python-docx looks for
"word/document.xml", which triggers a KeyError about a missing relationship.
This function rewrites the ZIP in-memory, normalizing all entry names to
use forward slashes without touching any actual document content.
"""
try:
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
out_buf = io.BytesIO()
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
for item in zin.infolist():
data = zin.read(item.filename)
# Normalize backslash path separators to forward slash
item.filename = item.filename.replace("\\", "/")
zout.writestr(item, data)
return out_buf.getvalue()
except zipfile.BadZipFile:
# Not a valid zip — return as-is and let python-docx report the real error
return file_content
def _extract_text_from_docx(file_content: bytes) -> str:
"""
Extract text from a DOCX file.
@@ -430,15 +386,7 @@ def _extract_text_from_docx(file_content: bytes) -> str:
"""
try:
doc_file = io.BytesIO(file_content)
try:
doc = docx.Document(doc_file)
except Exception as e:
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
# Some DOCX files exported by tools like Evernote on Windows use
# backslash path separators in ZIP entries and/or single-quoted XML
# attributes, both of which break python-docx on Linux. Normalize and retry.
file_content = _normalize_docx_zip(file_content)
doc = docx.Document(io.BytesIO(file_content))
doc = docx.Document(doc_file)
text = []
# Keep track of paragraph and table positions
@@ -491,13 +439,13 @@ def _extract_text_from_docx(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes:
def _download_file_content(file: File) -> bytes:
"""Download the content of a file based on its transfer method."""
try:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
if file.remote_url is None:
raise FileDownloadError("Missing URL for remote file")
response = http_client.get(file.remote_url)
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
else:
@@ -506,10 +454,8 @@ def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
def _extract_text_from_file(
http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig
) -> str:
file_content = _download_file_content(http_client, file)
def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
file_content = _download_file_content(file)
if file.extension:
extracted_text = _extract_text_by_file_extension(
file_content=file_content,
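The `_normalize_docx_zip` helper removed above is a self-contained repair for DOCX archives whose ZIP entries use backslash separators. A minimal standard-library sketch of the same round-trip:

```python
import io
import zipfile


def normalize_zip_entry_paths(data: bytes) -> bytes:
    """Rewrite a ZIP archive in memory so entry names use '/' separators."""
    try:
        with zipfile.ZipFile(io.BytesIO(data), "r") as zin:
            out = io.BytesIO()
            with zipfile.ZipFile(out, "w", compression=zipfile.ZIP_DEFLATED) as zout:
                for item in zin.infolist():
                    payload = zin.read(item.filename)
                    # Normalize backslash path separators to forward slash.
                    item.filename = item.filename.replace("\\", "/")
                    zout.writestr(item, payload)
            return out.getvalue()
    except zipfile.BadZipFile:
        # Not a valid zip: return as-is and let the caller report the error.
        return data


# Build a malformed archive with a backslash entry name, then normalize it.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("word\\document.xml", b"<w:document/>")
fixed = normalize_zip_entry_paths(buf.getvalue())
with zipfile.ZipFile(io.BytesIO(fixed)) as z:
    print(z.namelist())  # ['word/document.xml']
```

Dropping this fallback means DOCX files exported by tools that write backslash entry names will again fail in `python-docx` on Linux/Mac.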

@@ -212,7 +212,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
"""
Extract files from response by checking both Content-Type header and URL
"""
dify_ctx = self.require_dify_context()
files: list[File] = []
is_file = response.is_file
content_type = response.content_type
@@ -237,8 +236,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=mime_type,
@@ -250,7 +249,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=dify_ctx.tenant_id,
tenant_id=self.tenant_id,
)
files.append(file)

@@ -3,8 +3,9 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
@@ -20,6 +21,7 @@ from dify_graph.repositories.human_input_form_repository import (
HumanInputFormRepository,
)
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
@@ -31,8 +33,6 @@ if TYPE_CHECKING:
_SELECTED_BRANCH_KEY = "selected_branch"
_INVOKE_FROM_DEBUGGER = "debugger"
_INVOKE_FROM_EXPLORE = "explore"
logger = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository,
form_repository: HumanInputFormRepository | None = None,
) -> None:
super().__init__(
id=id,
@@ -74,6 +74,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if form_repository is None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self.tenant_id,
)
self._form_repository = form_repository
@classmethod
@@ -157,39 +162,30 @@ class HumanInputNode(Node[HumanInputNodeData]):
return resolved_defaults
def _should_require_console_recipient(self) -> bool:
invoke_from = self._invoke_from_value()
if invoke_from == _INVOKE_FROM_DEBUGGER:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
if invoke_from == _INVOKE_FROM_EXPLORE:
if self.invoke_from == InvokeFrom.EXPLORE:
return self._node_data.is_webapp_enabled()
return False
def _display_in_ui(self) -> bool:
if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
return self._node_data.is_webapp_enabled()
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
dify_ctx = self.require_dify_context()
invoke_from = self._invoke_from_value()
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
return [
apply_debug_email_recipient(
method,
enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
user_id=dify_ctx.user_id,
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
user_id=self.user_id or "",
)
for method in enabled_methods
]
def _invoke_from_value(self) -> str:
invoke_from = self.require_dify_context().invoke_from
if isinstance(invoke_from, str):
return invoke_from
return str(getattr(invoke_from, "value", invoke_from))
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
@@ -223,11 +219,10 @@ class HumanInputNode(Node[HumanInputNodeData]):
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
dify_ctx = self.require_dify_context()
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=dify_ctx.app_id,
app_id=self.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
@@ -237,9 +232,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
dify_ctx.user_id
if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
else None
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
),
backstage_recipient_required=True,
)

@@ -587,14 +587,24 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.runtime import GraphRuntimeState
# Create GraphInitParams for child graph execution.
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
run_context=self.run_context,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
)
# Create a deep copy of the variable pool for each iteration
@@ -611,17 +621,28 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
total_tokens=0,
node_run_steps=0,
)
root_node_id = self.node_data.start_node_id
if root_node_id is None:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
try:
return self.graph_runtime_state.create_child_engine(
workflow_id=self.workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state_copy,
graph_config=self.graph_config,
root_node_id=root_node_id,
)
except ChildGraphNotFoundError as exc:
raise IterationGraphNotFoundError("iteration graph not found") from exc
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine
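Both the old and new versions of `_create_graph_engine` give each iteration an isolated copy of the variable pool. The underlying pattern, sketched minimally with illustrative names:

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class RuntimeState:
    """Toy stand-in for a graph runtime state with a variable pool."""

    variables: dict = field(default_factory=dict)
    total_tokens: int = 0


def make_child_state(parent: RuntimeState, item: object) -> RuntimeState:
    # Deep-copy the pool so the child iteration can mutate freely
    # without leaking changes back into the parent run.
    child = RuntimeState(variables=deepcopy(parent.variables))
    child.variables["item"] = item
    return child


parent = RuntimeState(variables={"inputs": [1, 2, 3]})
child = make_child_state(parent, item=42)
child.variables["inputs"].append(4)
print(parent.variables["inputs"])  # [1, 2, 3]
```

What changes in this diff is only who builds the engine: the builder bound to the runtime state before, a directly constructed `GraphEngine` with an `InMemoryChannel` after.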

@@ -3,7 +3,7 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
@@ -20,7 +20,6 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_INVOKE_FROM_DEBUGGER = "debugger"
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
@@ -59,8 +58,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
invoke_from_value = str(invoke_from.value) if invoke_from else None
is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
chunks = variable.value
variables = {"chunks": chunks}

@@ -23,11 +23,7 @@ from dify_graph.variables import (
)
from dify_graph.variables.segments import ArrayObjectSegment
from .entities import (
Condition,
KnowledgeRetrievalNodeData,
MetadataFilteringCondition,
)
from .entities import KnowledgeRetrievalNodeData
from .exc import (
KnowledgeRetrievalNodeError,
RateLimitExceededError,
@@ -70,10 +66,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
@@ -120,7 +115,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -165,7 +160,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
dify_ctx = self.require_dify_context()
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
@@ -175,12 +169,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
resolved_metadata_conditions = (
self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
if node_data.metadata_filtering_conditions
else None
)
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
@@ -188,10 +176,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
user_from=dify_ctx.user_from.value,
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
@@ -199,7 +187,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=resolved_metadata_conditions,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
)
@@ -241,10 +229,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=dify_ctx.app_id,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
user_from=dify_ctx.user_from.value,
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
@@ -257,7 +245,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=resolved_metadata_conditions,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
@@ -266,48 +254,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
def _resolve_metadata_filtering_conditions(
self, conditions: MetadataFilteringCondition
) -> MetadataFilteringCondition:
if conditions.conditions is None:
return MetadataFilteringCondition(
logical_operator=conditions.logical_operator,
conditions=None,
)
variable_pool = self.graph_runtime_state.variable_pool
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = segment_group.value[0].to_object()
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values = []
for v in value: # type: ignore
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(segment_group.value[0].to_object())
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
else:
resolved_value = value
resolved_conditions.append(
Condition(
name=cond.name,
comparison_operator=cond.comparison_operator,
value=resolved_value,
)
)
return MetadataFilteringCondition(
logical_operator=conditions.logical_operator or "and",
conditions=resolved_conditions,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

@@ -145,10 +145,9 @@ class LLMNode(Node[LLMNodeData]):
self._memory = memory
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
@@ -243,7 +242,7 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.require_dify_context().user_id,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
@@ -703,7 +702,7 @@ class LLMNode(Node[LLMNodeData]):
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.require_dify_context().tenant_id,
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,

@@ -412,14 +412,24 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return build_segment_with_type(var_type, value)
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.runtime import GraphRuntimeState
# Create GraphInitParams for child graph execution.
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
run_context=self.run_context,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
)
@@ -429,10 +439,22 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
start_at=start_at.timestamp(),
)
return self.graph_runtime_state.create_child_engine(
workflow_id=self.workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state_copy,
graph_config=self.graph_config,
root_node_id=root_node_id,
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# Initialize the loop graph with the new node factory
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=loop_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

@@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
tools=tools,
stop=list(stop),
stream=False,
user=self.require_dify_context().user_id,
user=self.user_id,
)
# handle invoke result

@@ -86,10 +86,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._memory = memory
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
@@ -161,7 +160,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.require_dify_context().user_id,
user_id=self.user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,

@@ -56,8 +56,6 @@ class ToolNode(Node[ToolNodeData]):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
dify_ctx = self.require_dify_context()
# fetch tool icon
tool_info = {
"provider_type": self.node_data.provider_type.value,
@@ -77,12 +75,7 @@ class ToolNode(Node[ToolNodeData]):
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
self._node_id,
self.node_data,
dify_ctx.invoke_from,
variable_pool,
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@@ -116,10 +109,10 @@ class ToolNode(Node[ToolNodeData]):
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=dify_ctx.user_id,
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=dify_ctx.app_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
@@ -140,8 +133,8 @@ class ToolNode(Node[ToolNodeData]):
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)

@@ -69,7 +69,6 @@ class TriggerWebhookNode(Node[WebhookData]):
)
def generate_file_var(self, param_name: str, file: dict):
dify_ctx = self.require_dify_context()
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@@ -85,7 +84,7 @@ class TriggerWebhookNode(Node[WebhookData]):
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=dify_ctx.tenant_id,
tenant_id=self.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])

@@ -1,17 +1,9 @@
from .graph_runtime_state import (
ChildEngineBuilderNotConfiguredError,
ChildEngineError,
ChildGraphNotFoundError,
GraphRuntimeState,
)
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue
__all__ = [
"ChildEngineBuilderNotConfiguredError",
"ChildEngineError",
"ChildGraphNotFoundError",
"GraphRuntimeState",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",

@@ -15,7 +15,6 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.entities.pause_reason import PauseReason
@@ -136,31 +135,6 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
class ChildGraphEngineBuilderProtocol(Protocol):
def build_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: Mapping[str, Any],
root_node_id: str,
layers: Sequence[object] = (),
) -> Any: ...
class ChildEngineError(ValueError):
"""Base error type for child-engine creation failures."""
class ChildEngineBuilderNotConfiguredError(ChildEngineError):
"""Raised when child-engine creation is requested without a bound builder."""
class ChildGraphNotFoundError(ChildEngineError):
"""Raised when the requested child graph entry point cannot be resolved."""
class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""
@@ -235,7 +209,6 @@ class GraphRuntimeState:
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None
# Node and edge states that need to be restored into the graph object.
@@ -277,31 +250,6 @@ class GraphRuntimeState:
if self._graph is not None:
_ = self.response_coordinator
def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None:
self._child_engine_builder = builder
def create_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: Mapping[str, Any],
root_node_id: str,
layers: Sequence[object] = (),
) -> Any:
if self._child_engine_builder is None:
raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.")
return self._child_engine_builder.build_child_engine(
workflow_id=workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
graph_config=graph_config,
root_node_id=root_node_id,
layers=layers,
)
# ------------------------------------------------------------------
# Primary collaborators
# ------------------------------------------------------------------


@@ -65,15 +65,9 @@ class VariablePool(BaseModel):
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool. When restoring from a serialized
# snapshot, `variable_dictionary` already carries the latest runtime values.
# In that case, keep existing entries instead of overwriting them with the
# bootstrap list.
# Add conversation variables to the variable pool
for var in self.conversation_variables:
selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
if self._has(selector):
continue
self.add(selector, var)
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
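The hunk above removes the snapshot-preserving bootstrap for conversation variables. A stdlib-only sketch of the removed keep-existing behavior (class and method names are invented for illustration; this is not the real `VariablePool` API):

```python
# Sketch of the snapshot-preserving bootstrap removed by this hunk.
class MiniPool:
    def __init__(self):
        self._vars: dict[tuple[str, str], object] = {}

    def add(self, selector: tuple[str, str], value: object) -> None:
        self._vars[selector] = value

    def bootstrap_conversation_vars(self, defaults: dict[str, object]) -> None:
        for name, default in defaults.items():
            selector = ("conversation", name)
            if selector in self._vars:
                # A restored snapshot already carries the latest runtime value;
                # keep it instead of overwriting with the bootstrap default.
                continue
            self.add(selector, default)

pool = MiniPool()
pool.add(("conversation", "topic"), "restored-value")  # value from a snapshot
pool.bootstrap_conversation_vars({"topic": "default", "lang": "en"})
print(pool._vars[("conversation", "topic")])  # restored-value
print(pool._vars[("conversation", "lang")])   # en
```

With the removed guard, a restored value survives re-initialization; without it, the bootstrap default silently overwrites it.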


@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"


@@ -1,5 +1,3 @@
from typing import Any, cast
from sqlalchemy import select
from events.app_event import app_model_config_was_updated
@@ -56,11 +54,9 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s
continue
tool_type = list(tool.keys())[0]
tool_config = cast(dict[str, Any], list(tool.values())[0])
tool_config = list(tool.values())[0]
if tool_type == "dataset":
dataset_id = tool_config.get("id")
if isinstance(dataset_id, str):
dataset_ids.add(dataset_id)
dataset_ids.add(tool_config.get("id"))
# get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict


@@ -13,7 +13,6 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -67,7 +66,6 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)


@@ -18,7 +18,6 @@ from dify_app import DifyApp
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
@@ -182,18 +181,13 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]
sentinel_kwargs = {
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
"username": dify_config.REDIS_SENTINEL_USERNAME,
"password": dify_config.REDIS_SENTINEL_PASSWORD,
}
if dify_config.REDIS_MAX_CONNECTIONS:
sentinel_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
sentinel = Sentinel(
sentinel_hosts,
sentinel_kwargs=sentinel_kwargs,
sentinel_kwargs={
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
"username": dify_config.REDIS_SENTINEL_USERNAME,
"password": dify_config.REDIS_SENTINEL_PASSWORD,
},
)
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
@@ -210,15 +204,12 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
for node in dify_config.REDIS_CLUSTERS.split(",")
]
cluster_kwargs: dict[str, Any] = {
"startup_nodes": nodes,
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
}
if dify_config.REDIS_MAX_CONNECTIONS:
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
cluster: RedisCluster = RedisCluster(**cluster_kwargs)
cluster: RedisCluster = RedisCluster(
startup_nodes=nodes,
password=dify_config.REDIS_CLUSTERS_PASSWORD,
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
cache_config=_get_cache_configuration(),
)
return cluster
@@ -234,9 +225,6 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
}
)
if dify_config.REDIS_MAX_CONNECTIONS:
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
if ssl_kwargs:
redis_params.update(ssl_kwargs)
@@ -246,17 +234,9 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
max_conns = dify_config.REDIS_MAX_CONNECTIONS
if use_clusters:
if max_conns:
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
else:
return RedisCluster.from_url(pubsub_url)
if max_conns:
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
else:
return redis.Redis.from_url(pubsub_url)
return RedisCluster.from_url(pubsub_url)
return redis.Redis.from_url(pubsub_url)
def init_app(app: DifyApp):
@@ -289,11 +269,6 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
return StreamsBroadcastChannel(
_pubsub_redis_client,
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
)
return RedisBroadcastChannel(_pubsub_redis_client)


@@ -5,7 +5,7 @@ from typing import Union
from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in
from opentelemetry import metrics, trace
from opentelemetry import trace
from opentelemetry.propagate import set_global_textmap
from opentelemetry.propagators.b3 import B3Format
from opentelemetry.propagators.composite import CompositePropagator
@@ -31,29 +31,9 @@ def setup_context_propagation() -> None:
def shutdown_tracer() -> None:
flush_telemetry()
def flush_telemetry() -> None:
"""
Best-effort flush for telemetry providers.
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
so counters/histograms are exported before the process exits.
"""
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
try:
provider.force_flush()
except Exception:
logger.exception("otel: failed to flush trace provider")
metric_provider = metrics.get_meter_provider()
if hasattr(metric_provider, "force_flush"):
try:
metric_provider.force_flush()
except Exception:
logger.exception("otel: failed to flush metric provider")
provider.force_flush()
def is_celery_worker():
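The removed `flush_telemetry` body guards `force_flush` with `hasattr` and swallows exporter errors so short-lived processes can exit cleanly. A self-contained sketch of that pattern (the dummy provider class is invented for illustration):

```python
import logging

def best_effort_flush(provider, logger) -> None:
    """Best-effort flush, mirroring the removed flush_telemetry helper."""
    if hasattr(provider, "force_flush"):
        try:
            provider.force_flush()
        except Exception:
            # Swallow and log: a failed flush must not crash process shutdown.
            logger.exception("failed to flush provider")

class FlakyProvider:
    """Stand-in provider whose exporter is unavailable."""
    def __init__(self):
        self.calls = 0

    def force_flush(self):
        self.calls += 1
        raise RuntimeError("exporter unavailable")

provider = FlakyProvider()
best_effort_flush(provider, logging.getLogger(__name__))
print(provider.calls)  # 1 -- flush was attempted, error logged instead of raised
```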


@@ -1,159 +0,0 @@
from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
logger = logging.getLogger(__name__)
class StreamsBroadcastChannel:
"""
Redis Streams based broadcast channel implementation.
Characteristics:
- At-least-once delivery for late subscribers within the stream retention window.
- Each topic is stored as a dedicated Redis Stream key.
- The stream key expires `retention_seconds` after the last event is published (to bound storage).
"""
def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
self._client = redis_client
self._retention_seconds = max(int(retention_seconds or 0), 0)
def topic(self, topic: str) -> StreamsTopic:
return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds)
class StreamsTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
self._client = redis_client
self._topic = topic
self._key = f"stream:{topic}"
self._retention_seconds = retention_seconds
self.max_length = 5000
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length)
if self._retention_seconds > 0:
try:
self._client.expire(self._key, self._retention_seconds)
except Exception as e:
logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True)
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _StreamsSubscription(self._client, self._key)
class _StreamsSubscription(Subscription):
_SENTINEL = object()
def __init__(self, client: Redis | RedisCluster, key: str):
self._client = client
self._key = key
self._closed = threading.Event()
self._last_id = "0-0"
self._queue: queue.Queue[object] = queue.Queue()
self._start_lock = threading.Lock()
self._listener: threading.Thread | None = None
def _listen(self) -> None:
try:
while not self._closed.is_set():
streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
if not streams:
continue
for _key, entries in streams:
for entry_id, fields in entries:
data = None
if isinstance(fields, dict):
data = fields.get(b"data")
data_bytes: bytes | None = None
if isinstance(data, str):
data_bytes = data.encode()
elif isinstance(data, (bytes, bytearray)):
data_bytes = bytes(data)
if data_bytes is not None:
self._queue.put_nowait(data_bytes)
self._last_id = entry_id
finally:
self._queue.put_nowait(self._SENTINEL)
self._listener = None
def _start_if_needed(self) -> None:
if self._listener is not None:
return
# Ensure only one listener thread is created under concurrent calls
with self._start_lock:
if self._listener is not None or self._closed.is_set():
return
self._listener = threading.Thread(
target=self._listen,
name=f"redis-streams-sub-{self._key}",
daemon=True,
)
self._listener.start()
def __iter__(self) -> Iterator[bytes]:
# Iterator delegates to receive with timeout; stops on closure.
self._start_if_needed()
while not self._closed.is_set():
item = self.receive(timeout=1)
if item is not None:
yield item
def receive(self, timeout: float | None = 0.1) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis streams subscription is closed")
self._start_if_needed()
try:
if timeout is None:
item = self._queue.get()
else:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
if item is self._SENTINEL or self._closed.is_set():
raise SubscriptionClosedError("The Redis streams subscription is closed")
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
return bytes(item)
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
listener = self._listener
if listener is not None:
listener.join(timeout=2.0)
if listener.is_alive():
logger.warning(
"Streams subscription listener for key %s did not stop within timeout; keeping reference.",
self._key,
)
else:
self._listener = None
# Context manager helpers
def __enter__(self) -> Self:
self._start_if_needed()
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
self.close()
return None


@@ -66,7 +66,6 @@ def run_migrations_offline():
context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True
)
logger.info("Generating offline migration SQL with url: %s", url)
with context.begin_transaction():
context.run_migrations()


@@ -1,37 +0,0 @@
"""add partial indexes on conversations for app_id with created_at and updated_at
Revision ID: e288952f2994
Revises: fce013ca180e
Create Date: 2026-02-26 13:36:45.928922
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e288952f2994'
down_revision = 'fce013ca180e'
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.create_index(
'conversation_app_created_at_idx',
['app_id', sa.literal_column('created_at DESC')],
unique=False,
postgresql_where=sa.text('is_deleted IS false'),
)
batch_op.create_index(
'conversation_app_updated_at_idx',
['app_id', sa.literal_column('updated_at DESC')],
unique=False,
postgresql_where=sa.text('is_deleted IS false'),
)
def downgrade():
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_index('conversation_app_updated_at_idx')
batch_op.drop_index('conversation_app_created_at_idx')


@@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -15,7 +15,6 @@ from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
@@ -37,259 +36,6 @@ if TYPE_CHECKING:
from .workflow import Workflow
# --- TypedDict definitions for structured dict return types ---
class EnabledConfig(TypedDict):
enabled: bool
class EmbeddingModelInfo(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationReplyDisabledConfig(TypedDict):
enabled: Literal[False]
class AnnotationReplyEnabledConfig(TypedDict):
id: str
enabled: Literal[True]
score_threshold: float
embedding_model: EmbeddingModelInfo
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
class SensitiveWordAvoidanceConfig(TypedDict):
enabled: bool
type: str
config: dict[str, Any]
class AgentToolConfig(TypedDict):
provider_type: str
provider_id: str
tool_name: str
tool_parameters: dict[str, Any]
plugin_unique_identifier: NotRequired[str | None]
credential_id: NotRequired[str | None]
class AgentModeConfig(TypedDict):
enabled: bool
strategy: str | None
tools: list[AgentToolConfig | dict[str, Any]]
prompt: str | None
class ImageUploadConfig(TypedDict):
enabled: bool
number_limits: int
detail: str
transfer_methods: list[str]
class FileUploadConfig(TypedDict):
image: ImageUploadConfig
class DeletedToolInfo(TypedDict):
type: str
tool_name: str
provider_id: str
class ExternalDataToolConfig(TypedDict):
enabled: bool
variable: str
type: str
config: dict[str, Any]
class UserInputFormItemConfig(TypedDict):
variable: str
label: str
description: NotRequired[str]
required: NotRequired[bool]
max_length: NotRequired[int]
options: NotRequired[list[str]]
default: NotRequired[str]
type: NotRequired[str]
config: NotRequired[dict[str, Any]]
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
UserInputFormItem = dict[str, UserInputFormItemConfig]
class DatasetConfigs(TypedDict):
retrieval_model: str
datasets: NotRequired[dict[str, Any]]
top_k: NotRequired[int]
score_threshold: NotRequired[float]
score_threshold_enabled: NotRequired[bool]
reranking_model: NotRequired[dict[str, Any] | None]
weights: NotRequired[dict[str, Any] | None]
reranking_enabled: NotRequired[bool]
reranking_mode: NotRequired[str]
metadata_filtering_mode: NotRequired[str]
metadata_model_config: NotRequired[dict[str, Any] | None]
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
class ChatPromptMessage(TypedDict):
text: str
role: str
class ChatPromptConfig(TypedDict, total=False):
prompt: list[ChatPromptMessage]
class CompletionPromptText(TypedDict):
text: str
class ConversationHistoriesRole(TypedDict):
user_prefix: str
assistant_prefix: str
class CompletionPromptConfig(TypedDict):
prompt: CompletionPromptText
conversation_histories_role: NotRequired[ConversationHistoriesRole]
class ModelConfig(TypedDict):
provider: str
name: str
mode: str
completion_params: NotRequired[dict[str, Any]]
class AppModelConfigDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: EnabledConfig
speech_to_text: EnabledConfig
text_to_speech: EnabledConfig
retriever_resource: EnabledConfig
annotation_reply: AnnotationReplyConfig
more_like_this: EnabledConfig
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
external_data_tools: list[ExternalDataToolConfig]
model: ModelConfig
user_input_form: list[UserInputFormItem]
dataset_query_variable: str | None
pre_prompt: str | None
agent_mode: AgentModeConfig
prompt_type: str
chat_prompt_config: ChatPromptConfig
completion_prompt_config: CompletionPromptConfig
dataset_configs: DatasetConfigs
file_upload: FileUploadConfig
# Added dynamically in Conversation.model_config
model_id: NotRequired[str | None]
provider: NotRequired[str | None]
class ConversationDict(TypedDict):
id: str
app_id: str
app_model_config_id: str | None
model_provider: str | None
override_model_configs: str | None
model_id: str | None
mode: str
name: str
summary: str | None
inputs: dict[str, Any]
introduction: str | None
system_instruction: str | None
system_instruction_tokens: int
status: str
invoke_from: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
read_at: datetime | None
read_account_id: str | None
dialogue_count: int
created_at: datetime
updated_at: datetime
class MessageDict(TypedDict):
id: str
app_id: str
conversation_id: str
model_id: str | None
inputs: dict[str, Any]
query: str
total_price: Decimal | None
message: dict[str, Any]
answer: str
status: str
error: str | None
message_metadata: dict[str, Any]
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
agent_based: bool
workflow_run_id: str | None
class MessageFeedbackDict(TypedDict):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
class MessageFileInfo(TypedDict, total=False):
belongs_to: str | None
upload_file_id: str | None
id: str
tenant_id: str
type: str
transfer_method: str
remote_url: str | None
related_id: str | None
filename: str | None
extension: str | None
mime_type: str | None
size: int
dify_model_identity: str
url: str | None
class ExtraContentDict(TypedDict, total=False):
type: str
workflow_run_id: str
class TraceAppConfigDict(TypedDict):
id: str
app_id: str
tracing_provider: str | None
tracing_config: dict[str, Any]
is_active: bool
created_at: str | None
updated_at: str | None
class DifySetup(TypeBase):
__tablename__ = "dify_setups"
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -430,7 +176,7 @@ class App(Base):
return str(self.mode)
@property
def deleted_tools(self) -> list[DeletedToolInfo]:
def deleted_tools(self) -> list[dict[str, str]]:
from core.tools.tool_manager import ToolManager, ToolProviderType
from services.plugin.plugin_service import PluginService
@@ -511,7 +257,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools: list[DeletedToolInfo] = []
deleted_tools: list[dict[str, str]] = []
for tool in tools:
keys = list(tool.keys())
@@ -618,38 +364,35 @@ class AppModelConfig(TypeBase):
return app
@property
def model_dict(self) -> ModelConfig:
return cast(ModelConfig, json.loads(self.model) if self.model else {})
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {}
@property
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
@property
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return (
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False},
else {"enabled": False}
)
@property
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
@property
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
@property
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
def annotation_reply_dict(self) -> dict[str, Any]:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
@@ -672,62 +415,56 @@ class AppModelConfig(TypeBase):
return {"enabled": False}
@property
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@property
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
return cast(
SensitiveWordAvoidanceConfig,
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return (
json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance
else {"enabled": False, "type": "", "config": {}},
else {"enabled": False, "type": "", "configs": []}
)
@property
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
def external_data_tools_list(self) -> list[dict[str, Any]]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self) -> list[UserInputFormItem]:
def user_input_form_list(self) -> list[dict[str, Any]]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
def agent_mode_dict(self) -> AgentModeConfig:
return cast(
AgentModeConfig,
def agent_mode_dict(self) -> dict[str, Any]:
return (
json.loads(self.agent_mode)
if self.agent_mode
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
)
@property
def chat_prompt_config_dict(self) -> ChatPromptConfig:
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
return cast(
CompletionPromptConfig,
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
)
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property
def dataset_configs_dict(self) -> DatasetConfigs:
def dataset_configs_dict(self) -> dict[str, Any]:
if self.dataset_configs:
dataset_configs = json.loads(self.dataset_configs)
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
return cast(DatasetConfigs, dataset_configs)
return dataset_configs
return {
"retrieval_model": "multiple",
}
@property
def file_upload_dict(self) -> FileUploadConfig:
return cast(
FileUploadConfig,
def file_upload_dict(self) -> dict[str, Any]:
return (
json.loads(self.file_upload)
if self.file_upload
else {
@@ -737,10 +474,10 @@ class AppModelConfig(TypeBase):
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
}
)
def to_dict(self) -> AppModelConfigDict:
def to_dict(self) -> dict[str, Any]:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@@ -764,42 +501,36 @@ class AppModelConfig(TypeBase):
"file_upload": self.file_upload_dict,
}
def from_model_config_dict(self, model_config: AppModelConfigDict):
def from_model_config_dict(self, model_config: Mapping[str, Any]):
self.opening_statement = model_config.get("opening_statement")
self.suggested_questions = (
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
)
self.suggested_questions_after_answer = (
json.dumps(model_config.get("suggested_questions_after_answer"))
json.dumps(model_config["suggested_questions_after_answer"])
if model_config.get("suggested_questions_after_answer")
else None
)
self.speech_to_text = (
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
)
self.text_to_speech = (
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
)
self.more_like_this = (
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
)
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
self.sensitive_word_avoidance = (
json.dumps(model_config.get("sensitive_word_avoidance"))
json.dumps(model_config["sensitive_word_avoidance"])
if model_config.get("sensitive_word_avoidance")
else None
)
self.external_data_tools = (
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
)
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
self.user_input_form = (
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
)
self.dataset_query_variable = model_config.get("dataset_query_variable")
self.pre_prompt = model_config.get("pre_prompt")
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
self.pre_prompt = model_config["pre_prompt"]
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
self.retriever_resource = (
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
)
self.prompt_type = model_config.get("prompt_type", "simple")
self.chat_prompt_config = (
@@ -980,18 +711,6 @@ class Conversation(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
sa.Index(
"conversation_app_created_at_idx",
"app_id",
sa.text("created_at DESC"),
postgresql_where=sa.text("is_deleted IS false"),
),
sa.Index(
"conversation_app_updated_at_idx",
"app_id",
sa.text("updated_at DESC"),
postgresql_where=sa.text("is_deleted IS false"),
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -1092,26 +811,24 @@ class Conversation(Base):
self._inputs = inputs
@property
def model_config(self) -> AppModelConfigDict:
model_config = cast(AppModelConfigDict, {})
def model_config(self):
model_config = {}
app_model_config: AppModelConfig | None = None
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = cast(AppModelConfigDict, override_model_configs)
model_config = override_model_configs
else:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
# where is app_id?
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
cast(AppModelConfigDict, override_model_configs)
)
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
model_config["configs"] = override_model_configs
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
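The `model_config` hunk above resolves the effective config from three sources: advanced-chat uses the override JSON verbatim, other modes distinguish a full override (one with a `model` key) from partial overrides nested under `configs`, and otherwise the stored `AppModelConfig` is loaded. A minimal standalone sketch of that branching (the function name, the string mode value, and the loader callable are assumptions for illustration, not the model's API):

```python
import json


def resolve_model_config(mode, override_model_configs_json, load_app_model_config):
    """Resolve the effective app config, mirroring the branching in the hunk above.

    - advanced-chat: the override JSON is the config, verbatim.
    - other modes: a full override (has a "model" key) is used as the config;
      a partial override is nested under "configs"; otherwise fall back to
      the stored AppModelConfig.
    """
    if mode == "advanced-chat":
        return json.loads(override_model_configs_json) if override_model_configs_json else {}
    if override_model_configs_json:
        overrides = json.loads(override_model_configs_json)
        if "model" in overrides:
            return overrides  # a complete model config dict
        return {"configs": overrides}  # partial overrides, nested as in the hunk
    return load_app_model_config() or {}
```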
@@ -1286,7 +1003,7 @@ class Conversation(Base):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
def to_dict(self) -> ConversationDict:
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1566,7 +1283,7 @@ class Message(Base):
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self) -> list[MessageFileInfo]:
def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
@@ -1621,13 +1338,10 @@ class Message(Base):
)
files.append(file)
result = cast(
list[MessageFileInfo],
[
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
],
)
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
db.session.commit()
return result
@@ -1637,7 +1351,7 @@ class Message(Base):
self._extra_contents = list(contents)
@property
def extra_contents(self) -> list[ExtraContentDict]:
def extra_contents(self) -> list[dict[str, Any]]:
return getattr(self, "_extra_contents", [])
@property
@@ -1653,7 +1367,7 @@ class Message(Base):
return None
def to_dict(self) -> MessageDict:
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1677,7 +1391,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: MessageDict) -> Message:
def from_dict(cls, data: dict[str, Any]) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],
@@ -1737,7 +1451,7 @@ class MessageFeedback(TypeBase):
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self) -> MessageFeedbackDict:
def to_dict(self) -> dict[str, Any]:
return {
"id": str(self.id),
"app_id": str(self.app_id),
@@ -2000,8 +1714,8 @@ class AppMCPServer(TypeBase):
return result
@property
def parameters_dict(self) -> dict[str, str]:
return cast(dict[str, str], json.loads(self.parameters))
def parameters_dict(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.parameters))
class Site(Base):
@@ -2441,7 +2155,7 @@ class TraceAppConfig(TypeBase):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict)
def to_dict(self) -> TraceAppConfigDict:
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,


@@ -35,7 +35,7 @@ dependencies = [
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"markdown~=3.8.1",
"markdown~=3.5.1",
"mlflow-skinny>=3.0.0",
"numpy~=1.26.4",
"openpyxl~=3.1.5",
@@ -113,7 +113,7 @@ dev = [
"dotenv-linter~=0.5.0",
"faker~=38.2.0",
"lxml-stubs~=0.5.1",
"basedpyright~=1.38.2",
"basedpyright~=1.31.0",
"ruff~=0.14.0",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
@@ -167,12 +167,12 @@ dev = [
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy~=1.19.1",
"mypy~=1.17.1",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.55.0",
"pyrefly>=0.54.0",
]
############################################################
@@ -247,13 +247,3 @@ module = [
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
]
ignore_errors = true
[tool.pyrefly]
project-includes = ["."]
project-excludes = [
".venv",
"migrations/",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false


@@ -1,200 +0,0 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
controllers/console/app/mcp_server.py
controllers/console/app/site.py
controllers/console/auth/email_register.py
controllers/console/human_input_form.py
controllers/console/init_validate.py
controllers/console/ping.py
controllers/console/setup.py
controllers/console/version.py
controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
core/ops/aliyun_trace/data_exporter/traceclient.py
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
core/ops/mlflow_trace/mlflow_trace.py
core/ops/ops_trace_manager.py
core/ops/tencent_trace/client.py
core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/rag/retrieval/router/multi_dataset_function_call_router.py
core/rag/summary_index/summary_index.py
core/repositories/sqlalchemy_workflow_execution_repository.py
core/repositories/sqlalchemy_workflow_node_execution_repository.py
core/tools/__base/tool.py
core/tools/mcp_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/utils/message_transformer.py
core/tools/utils/web_reader_tool.py
core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
services/conversation_service.py
services/dataset_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
services/external_knowledge_service.py
services/plugin/plugin_migration.py
services/recommend_app/buildin/buildin_retrieval.py
services/recommend_app/database/database_retrieval.py
services/recommend_app/remote/remote_retrieval.py
services/summary_index_service.py
services/tools/tools_transform_service.py
services/trigger/trigger_provider_service.py
services/trigger/trigger_subscription_builder_service.py
services/trigger/webhook_service.py
services/workflow_draft_variable_service.py
services/workflow_event_snapshot_service.py
services/workflow_service.py
tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/workflow_execution_tasks.py

api/pyrefly.toml (new file, +8)

@@ -0,0 +1,8 @@
project-includes = ["."]
project-excludes = [
".venv",
"migrations/",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false


@@ -1,6 +1,5 @@
[pytest]
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
addopts = --cov=./api --cov-report=json
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
@@ -20,7 +19,7 @@ env =
GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
MOCK_SWITCH = true


@@ -74,16 +74,6 @@ from tasks.mail_reset_password_task import (
logger = logging.getLogger(__name__)
def _try_join_enterprise_default_workspace(account_id: str) -> None:
"""Best-effort join to enterprise default workspace."""
if not dify_config.ENTERPRISE_ENABLED:
return
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(account_id)
class TokenPair(BaseModel):
access_token: str
refresh_token: str
@@ -297,14 +287,13 @@ class AccountService:
email=email, name=name, interface_language=interface_language, password=password
)
try:
TenantService.create_owner_tenant_if_not_exist(account=account)
except Exception:
# Enterprise-only side-effect should run independently from personal workspace creation.
_try_join_enterprise_default_workspace(str(account.id))
raise
TenantService.create_owner_tenant_if_not_exist(account=account)
_try_join_enterprise_default_workspace(str(account.id))
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
if dify_config.ENTERPRISE_ENABLED:
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
return account
@@ -1418,18 +1407,18 @@ class RegisterService:
and create_workspace_required
and FeatureService.get_system_features().license.workspaces.is_available()
):
try:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
except Exception:
_try_join_enterprise_default_workspace(str(account.id))
raise
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
db.session.commit()
_try_join_enterprise_default_workspace(str(account.id))
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
if dify_config.ENTERPRISE_ENABLED:
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
logger.exception("Register failed")
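The `account_service.py` hunks above move the enterprise default-workspace join into a `try`/`except`, so the best-effort join still runs when personal-workspace creation fails, before the original error is re-raised. That control flow can be sketched like this (the function names are hypothetical stand-ins, not the service's real API):

```python
import logging

logger = logging.getLogger(__name__)


def best_effort(side_effect):
    """Run an enterprise-only side effect; failures are logged, never raised."""
    try:
        side_effect()
    except Exception:
        logger.exception("best-effort side effect failed")


def create_workspace_with_fallback(create_workspace, join_default_workspace):
    """Create the personal workspace; run the default-workspace join either way.

    On failure the join still gets its chance before the error propagates,
    which is the behavior the try/except in the hunk above introduces.
    """
    try:
        create_workspace()
    except Exception:
        best_effort(join_default_workspace)
        raise
    best_effort(join_default_workspace)
```

The point of the pattern is that the enterprise side effect is independent of workspace creation: it must neither be skipped on failure nor allowed to mask the original exception.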


@@ -4,7 +4,6 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@@ -33,7 +32,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig, AppModelConfigDict, IconType
from models.model import AppModelConfig, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@@ -524,7 +523,7 @@ class AppDslService:
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
).from_model_config_dict(model_config)
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id


@@ -38,13 +38,6 @@ if TYPE_CHECKING:
class AppGenerateService:
@staticmethod
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
"""
Build a subscription callback that coordinates when the background task starts.
- streams transport: start immediately (events are durable; late subscribers can replay).
- pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task
still runs if the client never connects.
"""
started = False
lock = threading.Lock()
@@ -61,18 +54,10 @@ class AppGenerateService:
started = True
return True
channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
if channel_type == "streams":
# With Redis Streams, we can safely start right away; consumers can read past events.
_try_start()
# Keep return type Callable[[], None] consistent while allowing an extra (no-op) call.
def _on_subscribe_streams() -> None:
_try_start()
return _on_subscribe_streams
# Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback.
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
# The Celery task may publish the first event before the API side actually subscribes,
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
# but also use a short fallback timer so the task still runs if the client never consumes.
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
timer.daemon = True
timer.start()
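The `_build_streaming_task_on_subscribe` diff above implements a start-exactly-once task: with Redis Streams the task starts immediately (events are replayable), while in the pub/sub modes the first subscriber triggers the start and a short fallback timer covers clients that never connect. A self-contained sketch of the pub/sub branch (the fallback constant's value is assumed; the real one lives in the service module):

```python
import threading

SSE_TASK_START_FALLBACK_MS = 200  # assumed value for illustration


def build_on_subscribe(start_task, fallback_ms=SSE_TASK_START_FALLBACK_MS):
    """Return a subscribe callback that starts `start_task` exactly once.

    The first subscriber triggers the start; a daemon fallback timer covers
    clients that never connect, as in the pub/sub branch of the diff above.
    """
    started = False
    lock = threading.Lock()

    def _try_start() -> bool:
        nonlocal started
        with lock:
            if started:
                return False
            started = True
        start_task()
        return True

    # Fallback: run the task even if no subscriber ever shows up.
    timer = threading.Timer(fallback_ms / 1000.0, _try_start)
    timer.daemon = True
    timer.start()

    def _on_subscribe() -> None:
        _try_start()
        timer.cancel()  # once started (by either path), the fallback is moot

    return _on_subscribe
```

The lock-guarded flag is what makes the race between subscriber and timer harmless: whichever fires first wins, and the loser becomes a no-op.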


@@ -1,12 +1,12 @@
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode, AppModelConfigDict
from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:


@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict, cast
from typing import TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
@@ -187,7 +187,7 @@ class AppService:
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@@ -388,7 +388,7 @@ class AppService:
agent_config = app_model_config.agent_mode_dict
# get all tools
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
tools = agent_config.get("tools", [])
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"


@@ -2,7 +2,6 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import cast
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@@ -107,7 +106,7 @@ class AudioService:
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = cast(str | None, text_to_speech_dict.get("voice"))
voice = text_to_speech_dict.get("voice")
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(


@@ -130,7 +130,7 @@ class HumanInputService:
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._form_repository = form_repository or HumanInputFormSubmissionRepository()
self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)
def get_form_by_token(self, form_token: str) -> Form | None:
record = self._form_repository.get_by_token(form_token)


@@ -63,12 +63,7 @@ class RagPipelineTransformService:
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if dataset.tenant_id != current_user.current_tenant_id:
raise ValueError("Unauthorized")
node = self._deal_knowledge_index(
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
)
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
@@ -160,13 +155,14 @@ class RagPipelineTransformService:
def _deal_knowledge_index(
self,
knowledge_configuration: KnowledgeConfiguration,
dataset: Dataset,
doc_form: str,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model

Some files were not shown because too many files have changed in this diff.