Compare commits

..

7 Commits

Author SHA1 Message Date
Stephen Zhou
c2fc4f4822 Use node 22 2026-03-03 13:31:35 +08:00
Stephen Zhou
67e0eeefd2 Add permission 2026-03-02 21:36:34 +08:00
Stephen Zhou
cf58035fa2 Add heapsnapshot-signal 2026-03-02 19:57:16 +08:00
Stephen Zhou
097f10d920 Merge branch 'hotfix/1.13.0-fix.2' into deploy/memory-usage 2026-03-02 19:33:07 +08:00
wangxiaolei
c34d05141e fix: fix chat assistant response mode blocking is not work (#32394)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-02 18:34:17 +08:00
Stephen Zhou
302701b303 chore: remove serwist 2026-03-02 17:19:15 +08:00
Stephen Zhou
8080159eaf fix: remove REDIRECT_URL_KEY from url (#32770) 2026-03-02 16:06:33 +08:00
1875 changed files with 52139 additions and 165521 deletions

View File

@@ -1,168 +0,0 @@
---
name: backend-code-review
description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review.
---
# Backend Code Review
## When to use this skill
Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. Supports the following review modes:
- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes).
- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review.
- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`).
Do NOT use this skill when:
- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`).
- The user is not asking for a review/analysis/improvement of backend code.
- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`).
## How to use this skill
Follow these steps when using this skill:
1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the users input. Keep the scope tight: review only what the user provided or explicitly referenced.
2. Follow the rules defined in **Checklist** to perform the review. If no Checklist rule matches, apply **General Review Rules** as a fallback to perform the best-effort review.
3. Compose the final output strictly follow the **Required Output Format**.
Notes when using this skill:
- Always include actionable fixes or suggestions (including possible code snippets).
- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can.
## Checklist
- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review
- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review
- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review
- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review
## General Review Rules
### 1. Security Review
Check for:
- SQL injection vulnerabilities
- Server-Side Request Forgery (SSRF)
- Command injection
- Insecure deserialization
- Hardcoded secrets/credentials
- Improper authentication/authorization
- Insecure direct object references
### 2. Performance Review
Check for:
- N+1 queries
- Missing database indexes
- Memory leaks
- Blocking operations in async code
- Missing caching opportunities
### 3. Code Quality Review
Check for:
- Code forward compatibility
- Code duplication (DRY violations)
- Functions doing too much (SRP violations)
- Deep nesting / complex conditionals
- Magic numbers/strings
- Poor naming
- Missing error handling
- Incomplete type coverage
### 4. Testing Review
Check for:
- Missing test coverage for new code
- Tests that don't test behavior
- Flaky test patterns
- Missing edge cases
## Required Output Format
When this skill invoked, the response must exactly follow one of the two templates:
### Template A (any findings)
```markdown
# Code Review Summary
Found <X> critical issues need to be fixed:
## 🔴 Critical (Must Fix)
### 1. <brief description of the issue>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<detailed explanation and references of the issue>
#### Suggested Fix
1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)
---
... (repeat for each critical issue) ...
Found <Y> suggestions for improvement:
## 🟡 Suggestions (Should Consider)
### 1. <brief description of the suggestion>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<detailed explanation and references of the suggestion>
#### Suggested Fix
1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)
---
... (repeat for each suggestion) ...
Found <Z> optional nits:
## 🟢 Nits (Optional)
### 1. <brief description of the nit>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<explanation and references of the optional nit>
#### Suggested Fix
- <minor suggestions>
---
... (repeat for each nits) ...
## ✅ What's Good
- <Positive feedback on good patterns>
```
- If there are no critical issues or suggestions or option nits or good points, just omit that section.
- If the issue number is more than 10, summarize as "Found 10+ critical issues/suggestions/optional nits" and only output the first 10 items.
- Don't compress the blank lines between sections; keep them as-is for readability.
- If there is any issue requires code changes, append a brief follow-up question to ask whether the user wants to apply the fix(es) after the structured output. For example: "Would you like me to use the Suggested fix(es) to address these issues?"
### Template B (no issues)
```markdown
## Code Review Summary
✅ No issues found.
```

View File

@@ -1,91 +0,0 @@
# Rule Catalog — Architecture
## Scope
- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow.
## Rules
### Keep business logic out of controllers
- Category: maintainability
- Severity: critical
- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test.
- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused.
- Example:
- Bad:
```python
@bp.post("/apps/<app_id>/publish")
def publish_app(app_id: str):
payload = request.get_json() or {}
if payload.get("force") and current_user.role != "admin":
raise ValueError("only admin can force publish")
app = App.query.get(app_id)
app.status = "published"
db.session.commit()
return {"result": "ok"}
```
- Good:
```python
@bp.post("/apps/<app_id>/publish")
def publish_app(app_id: str):
payload = PublishRequest.model_validate(request.get_json() or {})
app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id)
return {"result": "ok"}
```
### Preserve layer dependency direction
- Category: best practices
- Severity: critical
- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code.
- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower, not the reverse.
- Example:
- Bad:
```python
# core/policy/publish_policy.py
from controllers.console.app import request_context
def can_publish() -> bool:
return request_context.current_user.is_admin
```
- Good:
```python
# core/policy/publish_policy.py
def can_publish(role: str) -> bool:
return role == "admin"
# service layer adapts web/user context to domain input
allowed = can_publish(role=current_user.role)
```
### Keep libs business-agnostic
- Category: maintainability
- Severity: critical
- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions.
- Suggested fix:
- If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers.
- Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`.
- Example:
- Bad:
```python
# api/libs/conversation_filter.py
from services.conversation_service import ConversationService
def should_archive_conversation(conversation, tenant_id: str) -> bool:
# Domain policy and service dependency are leaking into libs.
service = ConversationService()
if service.has_paid_plan(tenant_id):
return conversation.idle_days > 90
return conversation.idle_days > 30
```
- Good:
```python
# api/libs/datetime_utils.py (business-agnostic helper)
def older_than_days(idle_days: int, threshold_days: int) -> bool:
return idle_days > threshold_days
# services/conversation_service.py (business logic stays in service/core)
from libs.datetime_utils import older_than_days
def should_archive_conversation(conversation, tenant_id: str) -> bool:
threshold_days = 90 if has_paid_plan(tenant_id) else 30
return older_than_days(conversation.idle_days, threshold_days)
```

View File

@@ -1,157 +0,0 @@
# Rule Catalog — DB Schema Design
## Scope
- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations.
- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`).
## Rules
### Do not query other tables inside `@property`
- Category: [maintainability, performance]
- Severity: critical
- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections.
- Suggested fix:
- Keep model properties pure and local to already-loaded fields.
- Move cross-table data fetching to service/repository methods.
- For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values.
- Example:
- Bad:
```python
class Conversation(TypeBase):
__tablename__ = "conversations"
@property
def app_name(self) -> str:
with Session(db.engine, expire_on_commit=False) as session:
app = session.execute(select(App).where(App.id == self.app_id)).scalar_one()
return app.name
```
- Good:
```python
class Conversation(TypeBase):
__tablename__ = "conversations"
@property
def display_title(self) -> str:
return self.name or "Untitled"
# Service/repository layer performs explicit batch fetch for related App rows.
```
### Prefer including `tenant_id` in model definitions
- Category: maintainability
- Severity: suggestion
- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows.
- Suggested fix:
- Add a `tenant_id` column and ensure related unique/index constraints include tenant dimension when applicable.
- Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware.
- Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly.
- Example:
- Bad:
```python
from sqlalchemy.orm import Mapped
class Dataset(TypeBase):
__tablename__ = "datasets"
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
```
- Good:
```python
from sqlalchemy.orm import Mapped
class Dataset(TypeBase):
__tablename__ = "datasets"
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
```
### Detect and avoid duplicate/redundant indexes
- Category: performance
- Severity: suggestion
- Description: Review index definitions for leftmost-prefix redundancy. For example, index `(a, b, c)` can safely cover most lookups for `(a, b)`. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans.
- Suggested fix:
- Before adding an index, compare against existing composite indexes by leftmost-prefix rules.
- Drop or avoid creating redundant prefixes unless there is a proven query-pattern need.
- Apply the same review standard in both model `__table_args__` and migration index DDL.
- Example:
- Bad:
```python
__table_args__ = (
sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"),
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
)
```
- Good:
```python
__table_args__ = (
# Keep the wider index unless profiling proves a dedicated short index is needed.
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
)
```
### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types`
- Category: maintainability
- Severity: critical
- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. When database-specific behavior is required, encapsulate it in `api/models/types.py` using both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code.
- Suggested fix:
- Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used.
- Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL.
- Example:
- Bad:
```python
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
class ToolConfig(TypeBase):
__tablename__ = "tool_configs"
config: Mapped[dict] = mapped_column(JSONB, nullable=False)
```
- Good:
```python
from sqlalchemy.orm import Mapped
from models.types import AdjustedJSON
class ToolConfig(TypeBase):
__tablename__ = "tool_configs"
config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False)
```
### Guard migration incompatibilities with dialect checks and shared types
- Category: maintainability
- Severity: critical
- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable.
- Suggested fix:
- In migration upgrades/downgrades, bind connection and branch by dialect for incompatible SQL fragments.
- Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models.
- Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception.
- Example:
- Bad:
```python
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
batch_op.add_column(
sa.Column(
"data_source_type",
sa.String(255),
server_default=sa.text("'database'::character varying"),
nullable=False,
)
)
```
- Good:
```python
def _is_pg(conn) -> bool:
return conn.dialect.name == "postgresql"
conn = op.get_bind()
default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'")
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
batch_op.add_column(
sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False)
)
```

View File

@@ -1,61 +0,0 @@
# Rule Catalog - Repositories Abstraction
## Scope
- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations.
- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`).
## Rules
### Introduce repositories abstraction
- Category: maintainability
- Severity: suggestion
- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation.
- Suggested fix:
- First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access.
- If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on abstraction; infra provides implementation).
- Example:
- Bad:
```python
# Existing repository is ignored and service uses ad-hoc table queries.
class AppService:
def archive_app(self, app_id: str, tenant_id: str) -> None:
app = self.session.execute(
select(App).where(App.id == app_id, App.tenant_id == tenant_id)
).scalar_one()
app.archived = True
self.session.commit()
```
- Good:
```python
# Case A: Existing repository must be reused for all table operations.
class AppService:
def archive_app(self, app_id: str, tenant_id: str) -> None:
app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
app.archived = True
self.app_repo.save(app)
# If the query is missing, extend the existing abstraction.
active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id)
```
- Bad:
```python
# No repository exists, but large-domain query logic is scattered in service code.
class ConversationService:
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
...
# many filters/joins/pagination variants duplicated across services
```
- Good:
```python
# Case B: Introduce repository for large/complex domains or storage variation.
class ConversationRepository(Protocol):
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ...
class SqlAlchemyConversationRepository:
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
...
class ConversationService:
def __init__(self, conversation_repo: ConversationRepository):
self.conversation_repo = conversation_repo
```

View File

@@ -1,139 +0,0 @@
# Rule Catalog — SQLAlchemy Patterns
## Scope
- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards.
- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`).
## Rules
### Use Session context manager with explicit transaction control behavior
- Category: best practices
- Severity: critical
- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk.
- Suggested fix:
- Use **explicit `session.commit()`** after completing a related write unit.
- Or use **`session.begin()` context manager** for automatic commit/rollback on a scoped block.
- Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction.
- Example:
- Bad:
```python
# Missing commit: write may never be persisted.
with Session(db.engine, expire_on_commit=False) as session:
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
# Long transaction: external I/O inside a DB transaction.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
call_external_api()
```
- Good:
```python
# Option 1: explicit commit.
with Session(db.engine, expire_on_commit=False) as session:
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
session.commit()
# Option 2: scoped transaction with automatic commit/rollback.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
# Keep non-DB work outside transaction scope.
call_external_api()
```
### Enforce tenant_id scoping on shared-resource queries
- Category: security
- Severity: critical
- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption.
- Suggested fix: Add `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces.
- Example:
- Bad:
```python
stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Good:
```python
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
### Prefer SQLAlchemy expressions over raw SQL by default
- Category: maintainability
- Severity: suggestion
- Description: Raw SQL should be exceptional. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase.
- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select/update/delete` expressions; keep raw SQL only when required by clear technical constraints.
- Example:
- Bad:
```python
row = session.execute(
text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"),
{"id": workflow_id, "tenant_id": tenant_id},
).first()
```
- Good:
```python
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
row = session.execute(stmt).scalar_one_or_none()
```
### Protect write paths with concurrency safeguards
- Category: quality
- Severity: critical
- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy.
- Suggested fix:
- **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or updated_at) guard in `WHERE` and treat `rowcount == 0` as a conflict.
- **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion.
- **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk.
- In all cases, scope by `tenant_id` and verify affected row counts for conditional writes.
- Example:
- Bad:
```python
# No tenant scope, no conflict detection, and no lock on a contested write path.
session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled"))
session.commit() # silently overwrites concurrent updates
```
- Good:
```python
# 1) Optimistic lock (low contention, retry on conflict)
result = session.execute(
update(WorkflowRun)
.where(
WorkflowRun.id == run_id,
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.version == expected_version,
)
.values(status="cancelled", version=WorkflowRun.version + 1)
)
if result.rowcount == 0:
raise WorkflowStateConflictError("stale version, retry")
# 2) Redis distributed lock (cross-worker critical section)
lock_name = f"workflow_run_lock:{tenant_id}:{run_id}"
with redis_client.lock(lock_name, timeout=20):
session.execute(
update(WorkflowRun)
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
.values(status="cancelled")
)
session.commit()
# 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention)
run = session.execute(
select(WorkflowRun)
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
.with_for_update()
).scalar_one()
run.status = "cancelled"
session.commit()
```

View File

@@ -1 +0,0 @@
../../.agents/skills/backend-code-review

View File

@@ -1,25 +1,12 @@
version: 2
multi-ecosystem-groups:
python:
schedule:
interval: "weekly" # or whatever schedule you want
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 2
patterns: ["*"]
schedule:
interval: "weekly"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 2
patterns: ["*"]
schedule:
interval: "weekly"
- package-ecosystem: "npm"
directory: "/web"
schedule:
interval: "weekly"
open-pull-requests-limit: 2
- package-ecosystem: "uv"
directory: "/api"
schedule:
interval: "weekly"
open-pull-requests-limit: 2

View File

@@ -1,88 +0,0 @@
name: Comment with Pyrefly Diff
on:
workflow_run:
workflows:
- Pyrefly Diff Check
types:
- completed
permissions: {}
jobs:
comment:
name: Comment PR with pyrefly diff
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
issues: write
pull-requests: write
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@v8
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }},
});
const match = artifacts.data.artifacts.find((artifact) =>
artifact.name === 'pyrefly_diff'
);
if (!match) {
throw new Error('pyrefly_diff artifact not found');
}
const download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: match.id,
archive_format: 'zip',
});
fs.writeFileSync('pyrefly_diff.zip', Buffer.from(download.data));
- name: Unzip artifact
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@v8
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
let prNumber = null;
try {
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
} catch (err) {
// Fallback to workflow_run payload if artifact is missing or incomplete.
const prs = context.payload.workflow_run.pull_requests || [];
if (prs.length > 0 && prs[0].number) {
prNumber = prs[0].number;
}
}
if (!prNumber) {
throw new Error('PR number not found in artifact or workflow_run payload');
}
const MAX_CHARS = 65000;
if (diff.length > MAX_CHARS) {
diff = diff.slice(0, MAX_CHARS);
diff = diff.slice(0, diff.lastIndexOf('\\n'));
diff += '\\n\\n... (truncated) ...';
}
const body = diff.trim()
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});

View File

@@ -1,94 +0,0 @@
name: Pyrefly Diff Check
on:
pull_request:
paths:
- 'api/**/*.py'
permissions:
contents: read
jobs:
pyrefly-diff:
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly on PR branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly on base branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
- name: Compute diff
run: |
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@v4
with:
name: pyrefly_diff
path: |
pyrefly_diff.txt
pr_number.txt
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@v8
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
const prNumber = context.payload.pull_request.number;
const MAX_CHARS = 65000;
if (diff.length > MAX_CHARS) {
diff = diff.slice(0, MAX_CHARS);
diff = diff.slice(0, diff.lastIndexOf('\n'));
diff += '\n\n... (truncated) ...';
}
const body = diff.trim()
? [
'### Pyrefly Diff',
'<details>',
'<summary>base → PR</summary>',
'',
'```diff',
diff,
'```',
'</details>',
].join('\n')
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});

View File

@@ -3,22 +3,14 @@ name: Web Tests
on:
workflow_call:
permissions:
contents: read
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
name: Web Tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
defaults:
run:
shell: bash
@@ -47,58 +39,7 @@ jobs:
run: pnpm install --frozen-lockfile
- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v6
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*
include-hidden-files: true
retention-days: 1
merge-reports:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Download blob reports
uses: actions/download-artifact@v6
with:
path: web/.vitest-reports
pattern: blob-report-*
merge-multiple: true
- name: Merge reports
run: pnpm vitest --merge-reports --coverage --silent=passed-only
run: pnpm test:ci
- name: Coverage Summary
if: always()

View File

@@ -1,5 +1,9 @@
![cover-v5-optimized](./images/GitHub_README_if.png)
<p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
</p>
<p align="center">
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·

View File

@@ -553,8 +553,6 @@ WORKFLOW_LOG_CLEANUP_ENABLED=false
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# Comma-separated list of workflow IDs to clean logs for
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS=
# App configuration
APP_MAX_EXECUTION_TIME=1200

View File

@@ -50,11 +50,14 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@@ -88,6 +91,7 @@ forbidden_modules =
core.logging
core.mcp
core.memory
core.model_manager
core.moderation
core.ops
core.plugin
@@ -101,18 +105,29 @@ forbidden_modules =
core.variables
ignore_imports =
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability
core.workflow.nodes.agent.agent_node -> core.model_manager
core.workflow.nodes.agent.agent_node -> core.provider_manager
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> models.model
core.workflow.nodes.datasource.datasource_node -> models.tools
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
core.workflow.nodes.document_extractor.node -> configs
core.workflow.nodes.document_extractor.node -> core.file.file_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.entities -> configs
core.workflow.nodes.http_request.executor -> configs
core.workflow.nodes.http_request.executor -> core.file.file_manager
core.workflow.nodes.http_request.node -> configs
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.file.models
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
@@ -142,18 +157,47 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.node_events.node -> core.file
core.workflow.nodes.agent.agent_node -> core.file
core.workflow.nodes.datasource.datasource_node -> core.file
core.workflow.nodes.datasource.datasource_node -> core.file.enums
core.workflow.nodes.document_extractor.node -> core.file
core.workflow.nodes.http_request.executor -> core.file.enums
core.workflow.nodes.http_request.node -> core.file
core.workflow.nodes.http_request.node -> core.file.file_manager
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
core.workflow.nodes.list_operator.node -> core.file
core.workflow.nodes.llm.file_saver -> core.file
core.workflow.nodes.llm.llm_utils -> core.variables.segments
core.workflow.nodes.llm.node -> core.file
core.workflow.nodes.llm.node -> core.file.file_manager
core.workflow.nodes.llm.node -> core.file.models
core.workflow.nodes.loop.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
core.workflow.nodes.protocols -> core.file
core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
core.workflow.nodes.tool.tool_node -> core.file
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
core.workflow.nodes.tool.tool_node -> models
core.workflow.nodes.trigger_webhook.node -> core.file
core.workflow.runtime.variable_pool -> core.file
core.workflow.runtime.variable_pool -> core.file.file_manager
core.workflow.system_variable -> core.file.models
core.workflow.utils.condition.processor -> core.file
core.workflow.utils.condition.processor -> core.file.file_manager
core.workflow.workflow_entry -> core.file.models
core.workflow.workflow_type_encoder -> core.file.models
core.workflow.nodes.agent.agent_node -> models.model
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
@@ -190,6 +234,7 @@ ignore_imports =
core.workflow.nodes.code.code_node -> core.variables.segments
core.workflow.nodes.code.code_node -> core.variables.types
core.workflow.nodes.code.entities -> core.variables.types
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
core.workflow.nodes.document_extractor.node -> core.variables
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
@@ -231,7 +276,9 @@ ignore_imports =
core.workflow.variable_loader -> core.variables
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
@@ -247,11 +294,6 @@ ignore_imports =
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services
core.workflow.nodes.tool.tool_node -> services
core.workflow.nodes.agent.agent_node -> core.model_runtime.token_buffer_memory
core.workflow.nodes.llm.llm_utils -> core.model_runtime.token_buffer_memory
core.workflow.nodes.llm.node -> core.model_runtime.token_buffer_memory
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.token_buffer_memory
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.token_buffer_memory
[importlinter:contract:model-runtime-no-internal-imports]
name = Model Runtime Internal Imports
@@ -304,13 +346,6 @@ ignore_imports =
core.model_runtime.model_providers.model_provider_factory -> configs
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
core.model_runtime.token_buffer_memory -> core.app.app_config.features.file_upload.manager
core.model_runtime.token_buffer_memory -> core.model_manager
core.model_runtime.token_buffer_memory -> core.prompt.utils.extract_thread_messages
core.model_runtime.token_buffer_memory -> core.workflow.file.file_manager
core.model_runtime.token_buffer_memory -> extensions.ext_database
core.model_runtime.token_buffer_memory -> models.model
core.model_runtime.token_buffer_memory -> models.workflow
[importlinter:contract:rsc]
name = RSC

View File

@@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a
1. Set up your application by visiting `http://localhost:3000`.
1. Start the worker service (async and scheduler tasks, runs from `api`).
1. Optional: start the worker service (async tasks, runs from `api`).
```bash
./dev/start-worker
@@ -54,6 +54,86 @@ The scripts resolve paths relative to their location, so you can run them from a
./dev/start-beat
```
### Manual commands
<details>
<summary>Show manual setup and run steps</summary>
These commands assume you start from the repository root.
1. Start the docker-compose stack.
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
```bash
cp docker/middleware.env.example docker/middleware.env
# Use mysql or another vector database profile if you are not using postgres/weaviate.
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
```
1. Copy env files.
```bash
cp api/.env.example api/.env
cp web/.env.example web/.env.local
```
1. Install UV if needed.
```bash
pip install uv
# Or on macOS
brew install uv
```
1. Install API dependencies.
```bash
cd api
uv sync --group dev
```
1. Install web dependencies.
```bash
cd web
pnpm install
cd ..
```
1. Start backend (runs migrations first, in a new terminal).
```bash
cd api
uv run flask db upgrade
uv run flask run --host 0.0.0.0 --port=5001 --debug
```
1. Start Dify [web](../web) service (in a new terminal).
```bash
cd web
pnpm dev:inspect
```
1. Set up your application by visiting `http://localhost:3000`.
1. Optional: start the worker service (async tasks, in a new terminal).
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
```bash
cd api
uv run celery -A app.celery beat
```
</details>
### Environment notes
> [!IMPORTANT]

View File

@@ -30,7 +30,6 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.db_migration_lock import DbMigrationAutoRenewLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
@@ -55,8 +54,6 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
DB_UPGRADE_LOCK_TTL_SECONDS = 60
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@@ -730,15 +727,8 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
@click.command("upgrade-db", help="Upgrade the database")
def upgrade_db():
click.echo("Preparing database migration...")
lock = DbMigrationAutoRenewLock(
redis_client=redis_client,
name="db_upgrade_lock",
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
logger=logger,
log_context="db_migration",
)
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
migration_succeeded = False
try:
click.echo(click.style("Starting database migration.", fg="green"))
@@ -747,7 +737,6 @@ def upgrade_db():
flask_migrate.upgrade()
migration_succeeded = True
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
@@ -755,8 +744,7 @@ def upgrade_db():
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
status = "successful" if migration_succeeded else "failed"
lock.release_safely(status=status)
lock.release()
else:
click.echo("Database migration skipped")

View File

@@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
default=60 * 60,
)
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
description="Maximum allowed size (bytes) for plugin-generated files",
default=50 * 1024 * 1024,
)
class MarketplaceConfig(BaseSettings):
"""
@@ -1319,9 +1314,6 @@ class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
)
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS: str = Field(
default="", description="Comma-separated list of workflow IDs to clean logs for"
)
class SwaggerUIConfig(BaseSettings):

View File

@@ -1,5 +1,3 @@
from typing import Literal
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@@ -51,43 +49,3 @@ class OceanBaseVectorConfig(BaseSettings):
),
default="ik",
)
OCEANBASE_VECTOR_BATCH_SIZE: PositiveInt = Field(
description="Number of documents to insert per batch",
default=100,
)
OCEANBASE_VECTOR_METRIC_TYPE: Literal["l2", "cosine", "inner_product"] = Field(
description="Distance metric type for vector index: l2, cosine, or inner_product",
default="l2",
)
OCEANBASE_HNSW_M: PositiveInt = Field(
description="HNSW M parameter (max number of connections per node)",
default=16,
)
OCEANBASE_HNSW_EF_CONSTRUCTION: PositiveInt = Field(
description="HNSW efConstruction parameter (index build-time search width)",
default=256,
)
OCEANBASE_HNSW_EF_SEARCH: int = Field(
description="HNSW efSearch parameter (query-time search width, -1 uses server default)",
default=-1,
)
OCEANBASE_VECTOR_POOL_SIZE: PositiveInt = Field(
description="SQLAlchemy connection pool size",
default=5,
)
OCEANBASE_VECTOR_MAX_OVERFLOW: int = Field(
description="SQLAlchemy connection pool max overflow connections",
default=10,
)
OCEANBASE_HNSW_REFRESH_THRESHOLD: int = Field(
description="Minimum number of inserted documents to trigger an automatic HNSW index refresh (0 to disable)",
default=1000,
)

View File

@@ -4,7 +4,7 @@ from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, computed_field
from core.workflow.file import helpers as file_helpers
from core.file import helpers as file_helpers
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]

View File

@@ -23,10 +23,10 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.file import helpers as file_helpers
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
@@ -660,19 +660,6 @@ class AppCopyApi(Resource):
)
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
try:
# Get the original app's access mode
original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
access_mode = original_settings.access_mode
except Exception:
# If original app has no settings (old app), default to public to match fallback behavior
access_mode = "public"
# Apply the same access mode to the copied app
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)

View File

@@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginInvokeError
@@ -30,10 +31,8 @@ from core.trigger.debug.event_selectors import (
select_trigger_debug_events,
)
from core.workflow.enums import NodeType
from core.workflow.file.models import File
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -741,7 +740,7 @@ class WorkflowTaskStopApi(Resource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@@ -15,11 +15,11 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.file import helpers as file_helpers
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
@@ -112,11 +112,11 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = {
**_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
"value": fields.Raw(attribute=_serialize_var_value),
"full_content": fields.Raw(attribute=_serialize_full_content),
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
full_content=fields.Raw(attribute=_serialize_full_content),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,

View File

@@ -10,7 +10,7 @@ import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -44,7 +44,6 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.app_fields import (
app_detail_fields_with_site,
deleted_tool_fields,
@@ -226,7 +225,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
@@ -470,7 +469,7 @@ class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial(None)
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app site info.
@@ -492,7 +491,7 @@ class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial(None)
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app parameters."""
@@ -521,7 +520,7 @@ class TrialAppParameterApi(Resource):
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@get_app_model_with_trial
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
"""Get app detail"""
@@ -534,7 +533,7 @@ class AppApi(Resource):
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@get_app_model_with_trial
@marshal_with(workflow_model)
def get(self, app_model):
"""Get workflow detail"""
@@ -553,7 +552,7 @@ class AppWorkflowApi(Resource):
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial(None)
@get_app_model_with_trial
def get(self, app_model):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
@@ -571,31 +570,27 @@ class DatasetListApi(Resource):
return response
console_ns.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
console_ns.add_resource(
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
console_ns.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
console_ns.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
console_ns.add_resource(
TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion"
)
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
console_ns.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
console_ns.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
console_ns.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
console_ns.add_resource(
TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run"
)
console_ns.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
console_ns.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
console_ns.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")

View File

@@ -23,7 +23,6 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from libs import helper
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
@@ -101,6 +100,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@@ -105,9 +105,9 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
return decorator
def trial_feature_enable(view: Callable[P, R]):
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
@@ -116,9 +116,9 @@ def trial_feature_enable(view: Callable[P, R]):
return decorated
def explore_banner_enabled(view: Callable[P, R]):
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")

View File

@@ -12,8 +12,8 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from core.workflow.file import helpers as file_helpers
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant, login_required

View File

@@ -137,7 +137,7 @@ class FilePreviewApi(Resource):
if args.as_attachment:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,

View File

@@ -64,10 +64,6 @@ class ToolFileApi(Resource):
if not stream or not tool_file:
raise NotFound("file is not found")
except NotFound:
raise
except Exception:
raise UnsupportedFileTypeError()

View File

@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
import services
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file.helpers import verify_plugin_file_signature
from fields.file_fields import FileResponse
from ..common.errors import (

View File

@@ -4,6 +4,7 @@ from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.file.helpers import get_signed_file_url_for_plugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.file.helpers import get_signed_file_url_for_plugin
from libs.helper import length_prefixed_response
from models import Account, Tenant
from models.model import EndUser

View File

@@ -31,7 +31,6 @@ from core.model_runtime.errors.invoke import InvokeError
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
@@ -281,7 +280,7 @@ class WorkflowTaskStopApi(Resource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@@ -3,8 +3,7 @@ from typing import Any
from flask import request
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
@@ -18,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models import Account
from models.dataset import Dataset, Pipeline
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
@@ -66,12 +65,6 @@ class DatasourcePluginsApi(DatasetApiResource):
)
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
@@ -111,12 +104,6 @@ class DatasourceNodeRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
@@ -174,12 +161,6 @@ class PipelineRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):

View File

@@ -10,8 +10,8 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from core.workflow.file import helpers as file_helpers
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from services.file_service import FileService

View File

@@ -24,7 +24,6 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@@ -122,6 +121,6 @@ class WorkflowTaskStopApi(WebApiResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@@ -17,6 +17,8 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -31,7 +33,6 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
@@ -39,7 +40,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from core.workflow.file import file_manager
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
@@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []

View File

@@ -245,7 +245,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model_name,
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
@@ -268,7 +268,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model_name,
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),

View File

@@ -1,6 +1,7 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@@ -10,7 +11,6 @@ from core.model_runtime.entities import (
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.file import file_manager
class CotChatAgentRunner(CotAgentRunner):

View File

@@ -7,6 +7,7 @@ from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
@@ -24,7 +25,6 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.file import file_manager
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
@@ -178,7 +178,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
yield LLMResultChunk(
model=model_instance.model_name,
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
@@ -308,7 +308,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model_name,
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),

View File

@@ -5,9 +5,9 @@ from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.file import FileTransferMethod, FileType, FileUploadConfig
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
from models.model import AppMode

View File

@@ -2,7 +2,7 @@ from collections.abc import Mapping
from typing import Any
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.workflow.file import FileUploadConfig
from core.file import FileUploadConfig
class FileUploadConfigManager:

View File

@@ -669,14 +669,16 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
yield from ()
return
yield # Make this a generator
def _handle_annotation_reply_event(
self, event: QueueAnnotationReplyEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
yield from ()
return
yield # Make this a generator
def _handle_message_replace_event(
self, event: QueueMessageReplaceEvent, **kwargs

View File

@@ -12,11 +12,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.moderation.base import ModerationError
from extensions.ext_database import db
from models.model import App, Conversation, Message
@@ -178,7 +178,7 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")

View File

@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.enums import NodeType
from core.workflow.file import File, FileUploadConfig
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,

View File

@@ -2,7 +2,7 @@ import logging
import queue
import threading
import time
from abc import ABC, abstractmethod
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any
@@ -31,7 +31,7 @@ class PublishFrom(IntEnum):
TASK_PIPELINE = auto()
class AppQueueManager(ABC):
class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id:
raise ValueError("user is required")
@@ -122,7 +122,7 @@ class AppQueueManager(ABC):
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@@ -133,7 +133,7 @@ class AppQueueManager(ABC):
self._publish(event, pub_from)
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

View File

@@ -22,6 +22,8 @@ from core.app.entities.queue_entities import (
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@@ -32,19 +34,17 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file.enums import FileTransferMethod, FileType
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
from core.workflow.file.models import File
from core.file.models import File
_logger = logging.getLogger(__name__)

View File

@@ -11,12 +11,12 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file import File
from extensions.ext_database import db
from models.model import App, Conversation, Message

View File

@@ -45,6 +45,7 @@ from core.app.entities.task_entities import (
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
@@ -59,7 +60,6 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.file import FILE_MODEL_IDENTITY, File
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry

View File

@@ -10,11 +10,11 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file import File
from extensions.ext_database import db
from models.model import App, Message

View File

@@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
from core.workflow.file import File, FileUploadConfig
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager

View File

@@ -1 +0,0 @@
"""LLM-related application services."""

View File

@@ -1,103 +0,0 @@
from __future__ import annotations
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
raise ValueError(f"Provider {provider_name} does not exist.")
provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
provider_model.raise_for_status()
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
return credentials
class DifyModelFactory:
tenant_id: str
model_manager: ModelManager
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
self.tenant_id = tenant_id
self.model_manager = model_manager or ModelManager()
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=provider_name,
model_type=ModelType.LLM,
model=model_name,
)
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
return (
DifyCredentialsProvider(tenant_id=tenant_id),
DifyModelFactory(tenant_id=tenant_id),
)
def fetch_model_config(
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=node_data_model.completion_params,
stop=stop,
)

View File

@@ -45,6 +45,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.file import helpers as file_helpers
from core.file.enums import FileTransferMethod
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@@ -57,8 +59,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from core.workflow.file import helpers as file_helpers
from core.workflow.file.enums import FileTransferMethod
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -157,7 +157,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
@@ -170,7 +170,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
@@ -283,7 +283,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
self._task_state.llm_result.message.get_text_content()
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -397,7 +397,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
if llm_result.message.content
else ""
)

View File

@@ -1,47 +0,0 @@
from __future__ import annotations
from collections.abc import Generator
from configs import dify_config
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from core.workflow.file.runtime import set_workflow_file_runtime
from extensions.ext_storage import storage
class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
"""Production runtime wiring for ``core.workflow.file``."""
@property
def files_url(self) -> str:
return dify_config.FILES_URL
@property
def internal_files_url(self) -> str | None:
return dify_config.INTERNAL_FILES_URL
@property
def secret_key(self) -> str:
return dify_config.SECRET_KEY
@property
def files_access_timeout(self) -> int:
return dify_config.FILES_ACCESS_TIMEOUT
@property
def multimodal_send_format(self) -> str:
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
def bind_dify_workflow_file_runtime() -> None:
set_workflow_file_runtime(DifyWorkflowFileRuntime())

View File

@@ -1,34 +1,28 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, final
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
from core.workflow.nodes.code.entities import CodeLanguage
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
@@ -37,24 +31,6 @@ if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
class DefaultWorkflowCodeExecutor:
def execute(
self,
*,
language: CodeLanguage,
code: str,
inputs: Mapping[str, Any],
) -> Mapping[str, Any]:
return CodeExecutor.execute_workflow_code_template(
language=language,
code=code,
inputs=inputs,
)
def is_execution_error(self, error: Exception) -> bool:
return isinstance(error, CodeExecutionError)
@final
class DifyNodeFactory(NodeFactory):
"""
@@ -68,12 +44,23 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
template_transform_max_output_length: int | None = None,
http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
self._code_providers: tuple[type[CodeNodeProvider], ...] = CodeNode.default_code_providers()
self._code_limits = CodeNodeLimits(
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
@@ -83,27 +70,14 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._http_request_file_manager = file_manager
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = (
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
)
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
)
self._http_request_config = build_http_request_config(
max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
max_write_timeout=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
max_binary_size=dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE,
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -164,31 +138,11 @@ class DifyNodeFactory(NodeFactory):
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_request_config=self._http_request_config,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.LLM:
return LLMNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
@@ -198,35 +152,6 @@ class DifyNodeFactory(NodeFactory):
rag_retrieval=self._rag_retrieval,
)
if node_type == NodeType.DOCUMENT_EXTRACTOR:
return DocumentExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
return QuestionClassifierNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
return ParameterExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
return node_class(
id=node_id,
config=node_config,

View File

@@ -213,6 +213,6 @@ class DatasourceFileManager:
# init tool_file_parser
# from core.workflow.file.datasource_file_parser import datasource_file_manager
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager

View File

@@ -1,39 +1,16 @@
import logging
from collections.abc import Generator
from threading import Lock
from typing import Any, cast
from sqlalchemy import select
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.file import File
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@@ -126,238 +103,3 @@ class DatasourceManager:
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
datasource_runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
return datasource_runtime.get_icon_url(tenant_id)
@classmethod
def stream_online_results(
cls,
*,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[DatasourceMessage, None, Any]:
"""
Pull-based streaming of domain messages from datasource plugins.
Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
"""
ds_type = DatasourceProviderType.value_of(datasource_type)
runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=ds_type,
)
dsp_service = DatasourceProviderService()
credentials = dsp_service.get_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
if credentials:
doc_runtime.runtime.credentials = credentials
if datasource_param is None:
raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
user_id=user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_param.workspace_id,
page_id=datasource_param.page_id,
type=datasource_param.type,
),
provider_type=ds_type,
)
elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
if credentials:
drive_runtime.runtime.credentials = credentials
if online_drive_request is None:
raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
inner_gen = drive_runtime.online_drive_download_file(
user_id=user_id,
request=OnlineDriveDownloadFileRequest(
id=online_drive_request.id,
bucket=online_drive_request.bucket,
),
provider_type=ds_type,
)
else:
raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
# Bridge through to caller while preserving generator return contract
yield from inner_gen
# No structured final data here; node/adapter will assemble outputs
return {}
@classmethod
def stream_node_events(
cls,
*,
node_id: str,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: Any,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
ds_type = DatasourceProviderType.value_of(datasource_type)
messages = cls.stream_online_results(
user_id=user_id,
datasource_name=datasource_name,
datasource_type=datasource_type,
provider_id=provider_id,
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
datasource_param=datasource_param,
online_drive_request=online_drive_request,
)
transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
)
variables: dict[str, Any] = {}
file_out: File | None = None
for message in transformed:
mtype = message.type
if mtype in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
wanted_ds_type = ds_type in {
DatasourceProviderType.ONLINE_DRIVE,
DatasourceProviderType.ONLINE_DOCUMENT,
}
if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
url = message.message.text
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with session_factory.create_session() as session:
stmt = select(ToolFile).where(
ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
)
datasource_file = session.scalar(stmt)
if not datasource_file:
raise ValueError(
f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
)
mime_type = datasource_file.mimetype
if datasource_file is not None:
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(mime_type),
"transfer_method": FileTransferMethod.TOOL_FILE,
"url": url,
}
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
elif mtype == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
elif mtype == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(
selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
)
elif mtype == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
name = message.message.variable_name
value = message.message.variable_value
if message.message.stream:
assert isinstance(value, str), "stream variable_value must be str"
variables[name] = variables.get(name, "") + value
yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
else:
variables[name] = value
elif mtype == DatasourceMessage.MessageType.FILE:
if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
f = message.meta.get("file")
if isinstance(f, File):
file_out = f
else:
pass
yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
variable_pool.add([node_id, "file"], file_out)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={**variables},
)
)
else:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file_out,
"datasource_type": ds_type,
},
)
)
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,
)
return file_info

View File

@@ -379,11 +379,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
"""
id: str = Field(..., description="The id of the file")
bucket: str = Field("", description="The name of the bucket")
@field_validator("bucket", mode="before")
@classmethod
def _coerce_bucket(cls, v) -> str:
if v is None:
return ""
return str(v)
bucket: str | None = Field(None, description="The name of the bucket")

View File

@@ -3,8 +3,8 @@ from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file import File, FileTransferMethod, FileType
from models.tools import ToolFile
logger = logging.getLogger(__name__)

View File

@@ -10,12 +10,12 @@ from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.file import helpers as file_helpers
if TYPE_CHECKING:
from models.tools import MCPToolProvider

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import base64
from collections.abc import Mapping
from configs import dify_config
from core.helper import ssrf_proxy
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
@@ -11,11 +11,12 @@ from core.model_runtime.entities import (
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .runtime import get_workflow_file_runtime
def get_attr(*, file: File, attr: FileAttribute):
@@ -44,7 +45,26 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
"""Convert a file to prompt message content."""
"""
Convert a file to prompt message content.
This function converts files to their appropriate prompt message content types.
For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
corresponding message content with proper encoding/URL.
For unsupported file types, instead of raising an error, it returns a
TextPromptMessageContent with a descriptive message about the file.
Args:
f: The file to convert
image_detail_config: Optional detail configuration for image files
Returns:
PromptMessageContentUnionTypes: The appropriate message content type
Raises:
ValueError: If file extension or mime_type is missing
"""
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
@@ -57,13 +77,15 @@ def to_prompt_message_content(
FileType.DOCUMENT: DocumentPromptMessageContent,
}
# Check if file type is supported
if f.type not in prompt_class_map:
# For unsupported file types, return a text description
return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
send_format = get_workflow_file_runtime().multimodal_send_format
# Process supported file types
params = {
"base64_data": _get_encoded_string(f) if send_format == "base64" else "",
"url": _to_url(f) if send_format == "url" else "",
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
@@ -74,7 +96,7 @@ def to_prompt_message_content(
return prompt_class_map[f.type].model_validate(params)
def download(f: File, /) -> bytes:
def download(f: File, /):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
FileTransferMethod.LOCAL_FILE,
@@ -84,26 +106,39 @@ def download(f: File, /) -> bytes:
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
def _download_file_content(path: str, /) -> bytes:
"""Download and return a file from storage as bytes."""
data = get_workflow_file_runtime().storage_load(path, stream=False)
def _download_file_content(path: str, /):
"""
Download and return the contents of a file as bytes.
This function loads the file from storage and ensures it's in bytes format.
Args:
path (str): The path to the file in storage.
Returns:
bytes: The contents of the file as a bytes object.
Raises:
ValueError: If the loaded file is not a bytes object.
"""
data = storage.load(path, stream=False)
if not isinstance(data, bytes):
raise ValueError(f"file {path} is not a bytes object")
return data
def _get_encoded_string(f: File, /) -> str:
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
@@ -113,7 +148,8 @@ def _get_encoded_string(f: File, /) -> str:
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)
return base64.b64encode(data).decode("utf-8")
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
def _to_url(f: File, /):
@@ -126,15 +162,21 @@ def _to_url(f: File, /):
raise ValueError("Missing file related_id")
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:
raise ValueError("Missing file related_id or extension")
return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension)
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
class FileManager:
"""Adapter exposing file manager helpers behind FileManagerProtocol."""
"""
Adapter exposing file manager helpers behind FileManagerProtocol.
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
where a protocol-typed file manager is expected.
"""
def download(self, f: File, /) -> bytes:
return download(f)

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import base64
import hashlib
import hmac
@@ -7,21 +5,20 @@ import os
import time
import urllib.parse
from .runtime import get_workflow_file_runtime
from configs import dify_config
def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
runtime = get_workflow_file_runtime()
base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url)
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = runtime.secret_key.encode()
key = dify_config.SECRET_KEY.encode()
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign}
query = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign}
if as_attachment:
query["as_attachment"] = "true"
query_string = urllib.parse.urlencode(query)
@@ -30,63 +27,57 @@ def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_ex
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
runtime = get_workflow_file_runtime()
# Plugin access should use internal URL for Docker network communication.
base_url = runtime.internal_files_url or runtime.files_url
# Plugin access should use internal URL for Docker network communication
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
url = f"{base_url}/files/upload/for-plugin"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = runtime.secret_key.encode()
key = dify_config.SECRET_KEY.encode()
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str:
runtime = get_workflow_file_runtime()
return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
def verify_plugin_file_signature(
*, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
) -> bool:
runtime = get_workflow_file_runtime()
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = runtime.secret_key.encode()
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= runtime.files_access_timeout
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
runtime = get_workflow_file_runtime()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = runtime.secret_key.encode()
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= runtime.files_access_timeout
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
runtime = get_workflow_file_runtime()
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = runtime.secret_key.encode()
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= runtime.files_access_timeout
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View File

@@ -1,26 +1,16 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.tools.signature import sign_tool_file
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``."""
return helpers.get_signed_tool_file_url(
tool_file_id=tool_file_id,
extension=extension,
for_external=for_external,
)
class ImageConfig(BaseModel):
"""
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
@@ -132,11 +122,7 @@ class File(BaseModel):
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
return sign_tool_file(
tool_file_id=self.related_id,
extension=self.extension,
for_external=for_external,
)
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
@@ -151,7 +137,7 @@ class File(BaseModel):
}
@model_validator(mode="after")
def validate_after(self) -> File:
def validate_after(self):
match self.transfer_method:
case FileTransferMethod.REMOTE_URL:
if not self.remote_url:
@@ -174,5 +160,5 @@ class File(BaseModel):
return self._storage_key
@storage_key.setter
def storage_key(self, value: str) -> None:
def storage_key(self, value: str):
self._storage_key = value

View File

@@ -0,0 +1,12 @@
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]):
global _tool_file_manager_factory
_tool_file_manager_factory = factory

View File

@@ -4,6 +4,7 @@ from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -15,7 +16,6 @@ from core.model_runtime.entities import (
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.workflow.file import file_manager
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile

View File

@@ -35,7 +35,7 @@ class ModelInstance:
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.model = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self.provider_model_bundle.model_type_instance
@@ -163,7 +163,7 @@ class ModelInstance:
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
@@ -191,7 +191,7 @@ class ModelInstance:
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
@@ -215,7 +215,7 @@ class ModelInstance:
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user,
@@ -243,7 +243,7 @@ class ModelInstance:
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
@@ -264,7 +264,7 @@ class ModelInstance:
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
model=self.model,
credentials=self.credentials,
texts=texts,
),
@@ -294,7 +294,7 @@ class ModelInstance:
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
@@ -328,7 +328,7 @@ class ModelInstance:
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
@@ -352,7 +352,7 @@ class ModelInstance:
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
text=text,
user=user,
@@ -373,7 +373,7 @@ class ModelInstance:
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
file=file,
user=user,
@@ -396,7 +396,7 @@ class ModelInstance:
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
@@ -469,7 +469,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self.model_type_instance.get_tts_model_voices(
model=self.model_name, credentials=self.credentials, language=language
model=self.model, credentials=self.credentials, language=language
)

View File

@@ -39,7 +39,7 @@ class Moderation(Extensible, ABC):
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.

View File

@@ -14,7 +14,6 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
DIFY_APP_ID,
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
@@ -100,16 +99,6 @@ class AliyunDataTrace(BaseTraceInstance):
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
raise ValueError(f"Aliyun get project url failed: {str(e)}")
def _extract_app_id(self, trace_info: BaseTraceInfo) -> str:
"""Extract app_id from trace_info, trying metadata first then message_data."""
app_id = trace_info.metadata.get("app_id")
if app_id:
return str(app_id)
message_data = getattr(trace_info, "message_data", None)
if message_data is not None:
return str(getattr(message_data, "app_id", ""))
return ""
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
@@ -154,16 +143,13 @@ class AliyunDataTrace(BaseTraceInstance):
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_str,
),
DIFY_APP_ID: self._extract_app_id(trace_info),
},
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_str,
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
@@ -455,8 +441,6 @@ class AliyunDataTrace(BaseTraceInstance):
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
app_id = self._extract_app_id(trace_info)
if message_span_id:
message_span = SpanData(
trace_id=trace_metadata.trace_id,
@@ -465,16 +449,13 @@ class AliyunDataTrace(BaseTraceInstance):
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
outputs=outputs_json,
),
DIFY_APP_ID: app_id,
},
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
outputs=outputs_json,
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
@@ -488,16 +469,13 @@ class AliyunDataTrace(BaseTraceInstance):
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_json,
),
**({DIFY_APP_ID: app_id} if message_span_id is None else {}),
},
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_json,
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,

View File

@@ -3,9 +3,6 @@ from typing import Final
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
# Dify-specific attributes
DIFY_APP_ID: Final[str] = "dify.app_id"
# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"

View File

@@ -129,11 +129,11 @@ class LangfuseSpan(BaseModel):
default=None,
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
)
start_time: datetime | None = Field(
start_time: datetime | str | None = Field(
default_factory=datetime.now,
description="The time at which the span started, defaults to the current time.",
)
end_time: datetime | None = Field(
end_time: datetime | str | None = Field(
default=None,
description="The time at which the span ended. Automatically set by span.end().",
)
@@ -146,7 +146,7 @@ class LangfuseSpan(BaseModel):
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
"via the API.",
)
level: LevelEnum | None = Field(
level: str | None = Field(
default=None,
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
"traces with elevated error levels and for highlighting in the UI.",
@@ -222,16 +222,16 @@ class LangfuseGeneration(BaseModel):
default=None,
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
)
start_time: datetime | None = Field(
start_time: datetime | str | None = Field(
default_factory=datetime.now,
description="The time at which the generation started, defaults to the current time.",
)
completion_start_time: datetime | None = Field(
completion_start_time: datetime | str | None = Field(
default=None,
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
"down into time until completion started and completion duration.",
)
end_time: datetime | None = Field(
end_time: datetime | str | None = Field(
default=None,
description="The time at which the generation ended. Automatically set by generation.end().",
)

View File

@@ -18,7 +18,8 @@ except ImportError:
from importlib_metadata import version # type: ignore[import-not-found]
if TYPE_CHECKING:
from opentelemetry.metrics import Histogram, Meter
from opentelemetry.metrics import Meter
from opentelemetry.metrics._internal.instrument import Histogram
from opentelemetry.sdk.metrics.export import MetricReader
from opentelemetry import trace as trace_api

View File

@@ -3,8 +3,6 @@ from typing import Any
from pydantic import BaseModel
from configs import dify_config
# from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
@@ -124,7 +122,7 @@ class PluginToolManager(BasePluginClient):
},
)
return merge_blob_chunks(response, max_file_size=dify_config.PLUGIN_MAX_FILE_SIZE)
return merge_blob_chunks(response)
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]

View File

@@ -1,7 +1,7 @@
from typing import Any
from core.file.models import File
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.file.models import File
def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:

View File

@@ -2,7 +2,10 @@ from collections.abc import Mapping, Sequence
from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@@ -12,12 +15,9 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool

View File

@@ -3,13 +3,13 @@ from typing import cast
from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.prompt_transform import PromptTransform
@@ -47,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_config.model,
self.model_config.credentials,
self.history_messages,
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
)
if curr_message_tokens <= max_token_limit:
return self.history_messages
@@ -65,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform):
# a message is start with UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_config.model,
self.model_config.credentials,
prompt_messages,
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
)
# if current message token is overflow, drop all the prompts in current message and break
if curr_message_tokens > max_token_limit:

View File

@@ -1,10 +1,10 @@
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig

View File

@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any, cast
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
@@ -14,15 +16,13 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from models.model import AppMode
if TYPE_CHECKING:
from core.workflow.file.models import File
from core.file.models import File
class ModelMode(StrEnum):

View File

@@ -1,13 +1,12 @@
import json
import logging
import re
from typing import Any, Literal
import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance # type: ignore
from pyobvector import VECTOR, ObVecClient, l2_distance # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -20,14 +19,10 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
_VALID_TABLE_NAME_RE = re.compile(r"^[a-zA-Z0-9_]+$")
_DISTANCE_FUNC_MAP = {
"l2": l2_distance,
"cosine": cosine_distance,
"inner_product": inner_product,
}
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
class OceanBaseVectorConfig(BaseModel):
@@ -37,14 +32,6 @@ class OceanBaseVectorConfig(BaseModel):
password: str
database: str
enable_hybrid_search: bool = False
batch_size: int = 100
metric_type: Literal["l2", "cosine", "inner_product"] = "l2"
hnsw_m: int = 16
hnsw_ef_construction: int = 256
hnsw_ef_search: int = -1
pool_size: int = 5
max_overflow: int = 10
hnsw_refresh_threshold: int = 1000
@model_validator(mode="before")
@classmethod
@@ -62,23 +49,14 @@ class OceanBaseVectorConfig(BaseModel):
class OceanBaseVector(BaseVector):
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
if not _VALID_TABLE_NAME_RE.match(collection_name):
raise ValueError(
f"Invalid collection name '{collection_name}': "
"only alphanumeric characters and underscores are allowed."
)
super().__init__(collection_name)
self._config = config
self._hnsw_ef_search = self._config.hnsw_ef_search
self._hnsw_ef_search = -1
self._client = ObVecClient(
uri=f"{self._config.host}:{self._config.port}",
user=self._config.user,
password=self._config.password,
db_name=self._config.database,
pool_size=self._config.pool_size,
max_overflow=self._config.max_overflow,
pool_recycle=3600,
pool_pre_ping=True,
)
self._fields: list[str] = [] # List of fields in the collection
if self._client.check_table_exists(collection_name):
@@ -158,8 +136,8 @@ class OceanBaseVector(BaseVector):
field_name="vector",
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
index_name="vector_index",
metric_type=self._config.metric_type,
params={"M": self._config.hnsw_m, "efConstruction": self._config.hnsw_ef_construction},
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
)
self._client.create_table_with_index_params(
@@ -200,17 +178,6 @@ class OceanBaseVector(BaseVector):
else:
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
try:
self._client.perform_raw_text_sql(
f"CREATE INDEX IF NOT EXISTS idx_metadata_doc_id ON `{self._collection_name}` "
f"((CAST(metadata->>'$.document_id' AS CHAR(64))))"
)
except SQLAlchemyError:
logger.warning(
"Failed to create metadata functional index on '%s'; metadata queries may be slow without it.",
self._collection_name,
)
self._client.refresh_metadata([self._collection_name])
self._load_collection_fields()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -238,49 +205,24 @@ class OceanBaseVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
batch_size = self._config.batch_size
total = len(documents)
all_data = [
{
"id": doc_id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
}
for doc_id, doc, emb in zip(ids, documents, embeddings)
]
for start in range(0, total, batch_size):
batch = all_data[start : start + batch_size]
for id, doc, emb in zip(ids, documents, embeddings):
try:
self._client.insert(
table_name=self._collection_name,
data=batch,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
except Exception as e:
logger.exception(
"Failed to insert batch [%d:%d] into collection '%s'",
start,
start + len(batch),
self._collection_name,
)
raise Exception(
f"Failed to insert batch [{start}:{start + len(batch)}] into collection '{self._collection_name}'"
) from e
if self._config.hnsw_refresh_threshold > 0 and total >= self._config.hnsw_refresh_threshold:
try:
self._client.refresh_index(
table_name=self._collection_name,
index_name="vector_index",
)
except SQLAlchemyError:
logger.warning(
"Failed to refresh HNSW index after inserting %d documents into '%s'",
total,
"Failed to insert document with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to insert document with id '{id}'") from e
def text_exists(self, id: str) -> bool:
try:
@@ -470,7 +412,7 @@ class OceanBaseVector(BaseVector):
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=self._get_distance_func(),
distance_func=l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=_where_clause,
@@ -482,31 +424,14 @@ class OceanBaseVector(BaseVector):
)
raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
# Convert distance to score and prepare results for processing
results = []
for _text, metadata_str, distance in cur:
score = self._distance_to_score(distance)
score = 1 - distance / math.sqrt(2)
results.append((_text, metadata_str, score))
return self._process_search_results(results, score_threshold=score_threshold)
def _get_distance_func(self):
func = _DISTANCE_FUNC_MAP.get(self._config.metric_type)
if func is None:
raise ValueError(
f"Unsupported metric_type '{self._config.metric_type}'. Supported: {', '.join(_DISTANCE_FUNC_MAP)}"
)
return func
def _distance_to_score(self, distance: float) -> float:
metric = self._config.metric_type
if metric == "l2":
return 1.0 / (1.0 + distance)
elif metric == "cosine":
return 1.0 - distance
elif metric == "inner_product":
return -distance
raise ValueError(f"Unsupported metric_type '{metric}'")
def delete(self):
try:
self._client.drop_table_if_exist(self._collection_name)
@@ -539,13 +464,5 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
enable_hybrid_search=dify_config.OCEANBASE_ENABLE_HYBRID_SEARCH or False,
batch_size=dify_config.OCEANBASE_VECTOR_BATCH_SIZE,
metric_type=dify_config.OCEANBASE_VECTOR_METRIC_TYPE,
hnsw_m=dify_config.OCEANBASE_HNSW_M,
hnsw_ef_construction=dify_config.OCEANBASE_HNSW_EF_CONSTRUCTION,
hnsw_ef_search=dify_config.OCEANBASE_HNSW_EF_SEARCH,
pool_size=dify_config.OCEANBASE_VECTOR_POOL_SIZE,
max_overflow=dify_config.OCEANBASE_VECTOR_MAX_OVERFLOW,
hnsw_refresh_threshold=dify_config.OCEANBASE_HNSW_REFRESH_THRESHOLD,
),
)

View File

@@ -15,11 +15,11 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str] | None:
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
@@ -27,14 +27,14 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
def delete_by_ids(self, ids: list[str]):
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
def delete_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
@@ -46,7 +46,7 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def delete(self) -> None:
def delete(self):
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:

View File

@@ -35,9 +35,7 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
hash=hash,
provider_name=self._model_instance.provider,
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
)
.first()
)
@@ -54,7 +52,7 @@ class CacheEmbedding(Embeddings):
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model_name, self._model_instance.credentials
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
@@ -89,7 +87,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model_name,
model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -116,9 +114,7 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
hash=file_id,
provider_name=self._model_instance.provider,
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
)
.first()
)
@@ -135,7 +131,7 @@ class CacheEmbedding(Embeddings):
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model_name, self._model_instance.credentials
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
@@ -172,7 +168,7 @@ class CacheEmbedding(Embeddings):
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model_name,
model_name=self._model_instance.model,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -194,7 +190,7 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
@@ -237,7 +233,7 @@ class CacheEmbedding(Embeddings):
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)

View File

@@ -9,6 +9,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.entities.knowledge_entities import PreviewDetail
from core.file import File, FileTransferMethod, FileType, file_manager
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -34,7 +35,6 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping

View File

@@ -4,7 +4,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.workflow.file import File
from core.file import File
class ChildDocument(BaseModel):

View File

@@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model_name,
model=self.rerank_model_instance.model,
model_type=ModelType.RERANK,
)
if not is_support_vision:

View File

@@ -23,12 +23,13 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.token_buffer_memory import TokenBufferMemory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
@@ -60,7 +61,6 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.file import File, FileTransferMethod, FileType
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,

View File

@@ -2,14 +2,14 @@ import io
from collections.abc import Generator
from typing import Any
from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.workflow.file.enums import FileType
from core.workflow.file.file_manager import download
from services.model_provider_service import ModelProviderService

View File

@@ -7,13 +7,13 @@ from urllib.parse import urlencode
import httpx
from core.file.file_manager import download
from core.helper import ssrf_proxy
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.workflow.file.file_manager import download
API_TOOL_DEFAULT_TIMEOUT = (
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),

View File

@@ -12,6 +12,8 @@ from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
@@ -31,8 +33,6 @@ from core.tools.errors import (
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.file import FileType
from core.workflow.file.models import FileTransferMethod
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile

View File

@@ -243,7 +243,7 @@ class ToolFileManager:
# init tool_file_parser
from core.workflow.file.tool_file_parser import set_tool_file_manager_factory
from core.file.tool_file_parser import set_tool_file_manager_factory
def _factory() -> ToolFileManager:

View File

@@ -8,9 +8,9 @@ from uuid import UUID
import numpy as np
import pytz
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file import File, FileTransferMethod, FileType
from libs.login import current_user
from models import Account

View File

@@ -47,7 +47,7 @@ class ModelInvocationUtils:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")

View File

@@ -2,7 +2,7 @@ import re
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Any, TypedDict
from typing import Any
import httpx
from flask import request
@@ -14,12 +14,6 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParamet
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class InterfaceDict(TypedDict):
path: str
method: str
operation: dict[str, Any]
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
@@ -41,7 +35,7 @@ class ApiBasedToolSchemaParser:
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces: list[InterfaceDict] = []
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:

View File

@@ -8,6 +8,7 @@ from typing import Any, cast
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
@@ -18,7 +19,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from factories.file_factory import build_from_mapping
from models import Account, Tenant
from models.model import App, EndUser

View File

@@ -5,7 +5,7 @@ from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.file import File
from core.file import File
from .types import SegmentType

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from core.workflow.file.models import File
from core.file.models import File
if TYPE_CHECKING:
pass

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Protocol
class HttpResponseProtocol(Protocol):
"""Subset of response behavior needed by workflow file helpers."""
@property
def content(self) -> bytes: ...
def raise_for_status(self) -> object: ...
class WorkflowFileRuntimeProtocol(Protocol):
"""Runtime dependencies required by ``core.workflow.file``.
Implementations are expected to be provided by integration layers (for example,
``core.app.workflow.file_runtime``) so the workflow package avoids importing
application infrastructure modules directly.
"""
@property
def files_url(self) -> str: ...
@property
def internal_files_url(self) -> str | None: ...
@property
def secret_key(self) -> str: ...
@property
def files_access_timeout(self) -> int: ...
@property
def multimodal_send_format(self) -> str: ...
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ...
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...

View File

@@ -1,58 +0,0 @@
from __future__ import annotations
from collections.abc import Generator
from typing import NoReturn
from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
class WorkflowFileRuntimeNotConfiguredError(RuntimeError):
"""Raised when workflow file runtime dependencies were not configured."""
class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
def _raise(self) -> NoReturn:
raise WorkflowFileRuntimeNotConfiguredError(
"workflow file runtime is not configured, call set_workflow_file_runtime(...) first"
)
@property
def files_url(self) -> str:
self._raise()
@property
def internal_files_url(self) -> str | None:
self._raise()
@property
def secret_key(self) -> str:
self._raise()
@property
def files_access_timeout(self) -> int:
self._raise()
@property
def multimodal_send_format(self) -> str:
self._raise()
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
self._raise()
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
self._raise()
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._raise()
_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime()
def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None:
global _runtime
_runtime = runtime
def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol:
return _runtime

View File

@@ -1,9 +0,0 @@
from collections.abc import Callable
from typing import Any
_tool_file_manager_factory: Callable[[], Any] | None = None
def set_tool_file_manager_factory(factory: Callable[[], Any]):
global _tool_file_manager_factory
_tool_file_manager_factory = factory

View File

@@ -7,28 +7,12 @@ Each instance uses a unique key for its command queue.
"""
import json
from contextlib import AbstractContextManager
from typing import Any, Protocol, final
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
class RedisPipelineProtocol(Protocol):
"""Minimal Redis pipeline contract used by the command channel."""
def lrange(self, name: str, start: int, end: int) -> Any: ...
def delete(self, *names: str) -> Any: ...
def execute(self) -> list[Any]: ...
def rpush(self, name: str, *values: str) -> Any: ...
def expire(self, name: str, time: int) -> Any: ...
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
def get(self, name: str) -> Any: ...
class RedisClientProtocol(Protocol):
"""Redis client contract required by the command channel."""
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@final
@@ -42,7 +26,7 @@ class RedisChannel:
def __init__(
self,
redis_client: RedisClientProtocol,
redis_client: "RedisClientWrapper",
channel_key: str,
command_ttl: int = 3600,
) -> None:

View File

@@ -3,14 +3,13 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Callers must provide a Redis client dependency from outside the workflow package.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
@@ -18,6 +17,7 @@ from core.workflow.graph_engine.entities.commands import (
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -31,12 +31,8 @@ class GraphEngineManager:
by sending commands through Redis channels, without user validation.
"""
_redis_client: RedisClientProtocol
def __init__(self, redis_client: RedisClientProtocol) -> None:
self._redis_client = redis_client
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
@staticmethod
def send_stop_command(task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
@@ -45,31 +41,34 @@ class GraphEngineManager:
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
self._send_command(task_id, abort_command)
GraphEngineManager._send_command(task_id, abort_command)
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
@staticmethod
def send_pause_command(task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
self._send_command(task_id, pause_command)
GraphEngineManager._send_command(task_id, pause_command)
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
self._send_command(task_id, update_command)
GraphEngineManager._send_command(task_id, update_command)
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(self._redis_client, channel_key)
channel = RedisChannel(redis_client, channel_key)
try:
channel.send_command(command)

Some files were not shown because too many files have changed in this diff Show More