mirror of
https://github.com/langgenius/dify.git
synced 2026-04-06 04:41:45 +08:00
Compare commits
7 Commits
move-token
...
deploy/mem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2fc4f4822 | ||
|
|
67e0eeefd2 | ||
|
|
cf58035fa2 | ||
|
|
097f10d920 | ||
|
|
c34d05141e | ||
|
|
302701b303 | ||
|
|
8080159eaf |
@@ -1,168 +0,0 @@
|
||||
---
|
||||
name: backend-code-review
|
||||
description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review.
|
||||
---
|
||||
|
||||
# Backend Code Review
|
||||
|
||||
## When to use this skill
|
||||
|
||||
Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. Supports the following review modes:
|
||||
|
||||
- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes).
|
||||
- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review.
|
||||
- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`).
|
||||
|
||||
Do NOT use this skill when:
|
||||
|
||||
- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`).
|
||||
- The user is not asking for a review/analysis/improvement of backend code.
|
||||
- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`).
|
||||
|
||||
## How to use this skill
|
||||
|
||||
Follow these steps when using this skill:
|
||||
|
||||
1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the user’s input. Keep the scope tight: review only what the user provided or explicitly referenced.
|
||||
2. Follow the rules defined in **Checklist** to perform the review. If no Checklist rule matches, apply **General Review Rules** as a fallback to perform the best-effort review.
|
||||
3. Compose the final output strictly follow the **Required Output Format**.
|
||||
|
||||
Notes when using this skill:
|
||||
- Always include actionable fixes or suggestions (including possible code snippets).
|
||||
- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can.
|
||||
|
||||
## Checklist
|
||||
|
||||
- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review
|
||||
- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review
|
||||
- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review
|
||||
- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review
|
||||
|
||||
## General Review Rules
|
||||
|
||||
### 1. Security Review
|
||||
|
||||
Check for:
|
||||
- SQL injection vulnerabilities
|
||||
- Server-Side Request Forgery (SSRF)
|
||||
- Command injection
|
||||
- Insecure deserialization
|
||||
- Hardcoded secrets/credentials
|
||||
- Improper authentication/authorization
|
||||
- Insecure direct object references
|
||||
|
||||
### 2. Performance Review
|
||||
|
||||
Check for:
|
||||
- N+1 queries
|
||||
- Missing database indexes
|
||||
- Memory leaks
|
||||
- Blocking operations in async code
|
||||
- Missing caching opportunities
|
||||
|
||||
### 3. Code Quality Review
|
||||
|
||||
Check for:
|
||||
- Code forward compatibility
|
||||
- Code duplication (DRY violations)
|
||||
- Functions doing too much (SRP violations)
|
||||
- Deep nesting / complex conditionals
|
||||
- Magic numbers/strings
|
||||
- Poor naming
|
||||
- Missing error handling
|
||||
- Incomplete type coverage
|
||||
|
||||
### 4. Testing Review
|
||||
|
||||
Check for:
|
||||
- Missing test coverage for new code
|
||||
- Tests that don't test behavior
|
||||
- Flaky test patterns
|
||||
- Missing edge cases
|
||||
|
||||
## Required Output Format
|
||||
|
||||
When this skill invoked, the response must exactly follow one of the two templates:
|
||||
|
||||
### Template A (any findings)
|
||||
|
||||
```markdown
|
||||
# Code Review Summary
|
||||
|
||||
Found <X> critical issues need to be fixed:
|
||||
|
||||
## 🔴 Critical (Must Fix)
|
||||
|
||||
### 1. <brief description of the issue>
|
||||
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
#### Explanation
|
||||
|
||||
<detailed explanation and references of the issue>
|
||||
|
||||
#### Suggested Fix
|
||||
|
||||
1. <brief description of suggested fix>
|
||||
2. <code example> (optional, omit if not applicable)
|
||||
|
||||
---
|
||||
... (repeat for each critical issue) ...
|
||||
|
||||
Found <Y> suggestions for improvement:
|
||||
|
||||
## 🟡 Suggestions (Should Consider)
|
||||
|
||||
### 1. <brief description of the suggestion>
|
||||
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
#### Explanation
|
||||
|
||||
<detailed explanation and references of the suggestion>
|
||||
|
||||
#### Suggested Fix
|
||||
|
||||
1. <brief description of suggested fix>
|
||||
2. <code example> (optional, omit if not applicable)
|
||||
|
||||
---
|
||||
... (repeat for each suggestion) ...
|
||||
|
||||
Found <Z> optional nits:
|
||||
|
||||
## 🟢 Nits (Optional)
|
||||
### 1. <brief description of the nit>
|
||||
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
#### Explanation
|
||||
|
||||
<explanation and references of the optional nit>
|
||||
|
||||
#### Suggested Fix
|
||||
|
||||
- <minor suggestions>
|
||||
|
||||
---
|
||||
... (repeat for each nits) ...
|
||||
|
||||
## ✅ What's Good
|
||||
|
||||
- <Positive feedback on good patterns>
|
||||
```
|
||||
|
||||
- If there are no critical issues or suggestions or option nits or good points, just omit that section.
|
||||
- If the issue number is more than 10, summarize as "Found 10+ critical issues/suggestions/optional nits" and only output the first 10 items.
|
||||
- Don't compress the blank lines between sections; keep them as-is for readability.
|
||||
- If there is any issue requires code changes, append a brief follow-up question to ask whether the user wants to apply the fix(es) after the structured output. For example: "Would you like me to use the Suggested fix(es) to address these issues?"
|
||||
|
||||
### Template B (no issues)
|
||||
|
||||
```markdown
|
||||
## Code Review Summary
|
||||
✅ No issues found.
|
||||
```
|
||||
@@ -1,91 +0,0 @@
|
||||
# Rule Catalog — Architecture
|
||||
|
||||
## Scope
|
||||
- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow.
|
||||
|
||||
## Rules
|
||||
|
||||
### Keep business logic out of controllers
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test.
|
||||
- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
@bp.post("/apps/<app_id>/publish")
|
||||
def publish_app(app_id: str):
|
||||
payload = request.get_json() or {}
|
||||
if payload.get("force") and current_user.role != "admin":
|
||||
raise ValueError("only admin can force publish")
|
||||
app = App.query.get(app_id)
|
||||
app.status = "published"
|
||||
db.session.commit()
|
||||
return {"result": "ok"}
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
@bp.post("/apps/<app_id>/publish")
|
||||
def publish_app(app_id: str):
|
||||
payload = PublishRequest.model_validate(request.get_json() or {})
|
||||
app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id)
|
||||
return {"result": "ok"}
|
||||
```
|
||||
|
||||
### Preserve layer dependency direction
|
||||
- Category: best practices
|
||||
- Severity: critical
|
||||
- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code.
|
||||
- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower, not the reverse.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# core/policy/publish_policy.py
|
||||
from controllers.console.app import request_context
|
||||
|
||||
def can_publish() -> bool:
|
||||
return request_context.current_user.is_admin
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# core/policy/publish_policy.py
|
||||
def can_publish(role: str) -> bool:
|
||||
return role == "admin"
|
||||
|
||||
# service layer adapts web/user context to domain input
|
||||
allowed = can_publish(role=current_user.role)
|
||||
```
|
||||
|
||||
### Keep libs business-agnostic
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions.
|
||||
- Suggested fix:
|
||||
- If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers.
|
||||
- Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# api/libs/conversation_filter.py
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
def should_archive_conversation(conversation, tenant_id: str) -> bool:
|
||||
# Domain policy and service dependency are leaking into libs.
|
||||
service = ConversationService()
|
||||
if service.has_paid_plan(tenant_id):
|
||||
return conversation.idle_days > 90
|
||||
return conversation.idle_days > 30
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# api/libs/datetime_utils.py (business-agnostic helper)
|
||||
def older_than_days(idle_days: int, threshold_days: int) -> bool:
|
||||
return idle_days > threshold_days
|
||||
|
||||
# services/conversation_service.py (business logic stays in service/core)
|
||||
from libs.datetime_utils import older_than_days
|
||||
|
||||
def should_archive_conversation(conversation, tenant_id: str) -> bool:
|
||||
threshold_days = 90 if has_paid_plan(tenant_id) else 30
|
||||
return older_than_days(conversation.idle_days, threshold_days)
|
||||
```
|
||||
@@ -1,157 +0,0 @@
|
||||
# Rule Catalog — DB Schema Design
|
||||
|
||||
## Scope
|
||||
- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations.
|
||||
- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`).
|
||||
|
||||
## Rules
|
||||
|
||||
### Do not query other tables inside `@property`
|
||||
- Category: [maintainability, performance]
|
||||
- Severity: critical
|
||||
- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections.
|
||||
- Suggested fix:
|
||||
- Keep model properties pure and local to already-loaded fields.
|
||||
- Move cross-table data fetching to service/repository methods.
|
||||
- For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
class Conversation(TypeBase):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
@property
|
||||
def app_name(self) -> str:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app = session.execute(select(App).where(App.id == self.app_id)).scalar_one()
|
||||
return app.name
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
class Conversation(TypeBase):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
@property
|
||||
def display_title(self) -> str:
|
||||
return self.name or "Untitled"
|
||||
|
||||
|
||||
# Service/repository layer performs explicit batch fetch for related App rows.
|
||||
```
|
||||
|
||||
### Prefer including `tenant_id` in model definitions
|
||||
- Category: maintainability
|
||||
- Severity: suggestion
|
||||
- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows.
|
||||
- Suggested fix:
|
||||
- Add a `tenant_id` column and ensure related unique/index constraints include tenant dimension when applicable.
|
||||
- Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware.
|
||||
- Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class Dataset(TypeBase):
|
||||
__tablename__ = "datasets"
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class Dataset(TypeBase):
|
||||
__tablename__ = "datasets"
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
```
|
||||
|
||||
### Detect and avoid duplicate/redundant indexes
|
||||
- Category: performance
|
||||
- Severity: suggestion
|
||||
- Description: Review index definitions for leftmost-prefix redundancy. For example, index `(a, b, c)` can safely cover most lookups for `(a, b)`. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans.
|
||||
- Suggested fix:
|
||||
- Before adding an index, compare against existing composite indexes by leftmost-prefix rules.
|
||||
- Drop or avoid creating redundant prefixes unless there is a proven query-pattern need.
|
||||
- Apply the same review standard in both model `__table_args__` and migration index DDL.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
__table_args__ = (
|
||||
sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"),
|
||||
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
|
||||
)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
__table_args__ = (
|
||||
# Keep the wider index unless profiling proves a dedicated short index is needed.
|
||||
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
|
||||
)
|
||||
```
|
||||
|
||||
### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types`
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. When database-specific behavior is required, encapsulate it in `api/models/types.py` using both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code.
|
||||
- Suggested fix:
|
||||
- Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used.
|
||||
- Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class ToolConfig(TypeBase):
|
||||
__tablename__ = "tool_configs"
|
||||
config: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from models.types import AdjustedJSON
|
||||
|
||||
class ToolConfig(TypeBase):
|
||||
__tablename__ = "tool_configs"
|
||||
config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False)
|
||||
```
|
||||
|
||||
### Guard migration incompatibilities with dialect checks and shared types
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable.
|
||||
- Suggested fix:
|
||||
- In migration upgrades/downgrades, bind connection and branch by dialect for incompatible SQL fragments.
|
||||
- Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models.
|
||||
- Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"data_source_type",
|
||||
sa.String(255),
|
||||
server_default=sa.text("'database'::character varying"),
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
def _is_pg(conn) -> bool:
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
|
||||
conn = op.get_bind()
|
||||
default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'")
|
||||
|
||||
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False)
|
||||
)
|
||||
```
|
||||
@@ -1,61 +0,0 @@
|
||||
# Rule Catalog - Repositories Abstraction
|
||||
|
||||
## Scope
|
||||
- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations.
|
||||
- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`).
|
||||
|
||||
## Rules
|
||||
|
||||
### Introduce repositories abstraction
|
||||
- Category: maintainability
|
||||
- Severity: suggestion
|
||||
- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation.
|
||||
- Suggested fix:
|
||||
- First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access.
|
||||
- If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on abstraction; infra provides implementation).
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# Existing repository is ignored and service uses ad-hoc table queries.
|
||||
class AppService:
|
||||
def archive_app(self, app_id: str, tenant_id: str) -> None:
|
||||
app = self.session.execute(
|
||||
select(App).where(App.id == app_id, App.tenant_id == tenant_id)
|
||||
).scalar_one()
|
||||
app.archived = True
|
||||
self.session.commit()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# Case A: Existing repository must be reused for all table operations.
|
||||
class AppService:
|
||||
def archive_app(self, app_id: str, tenant_id: str) -> None:
|
||||
app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
|
||||
app.archived = True
|
||||
self.app_repo.save(app)
|
||||
|
||||
# If the query is missing, extend the existing abstraction.
|
||||
active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id)
|
||||
```
|
||||
- Bad:
|
||||
```python
|
||||
# No repository exists, but large-domain query logic is scattered in service code.
|
||||
class ConversationService:
|
||||
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
|
||||
...
|
||||
# many filters/joins/pagination variants duplicated across services
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# Case B: Introduce repository for large/complex domains or storage variation.
|
||||
class ConversationRepository(Protocol):
|
||||
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ...
|
||||
|
||||
class SqlAlchemyConversationRepository:
|
||||
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
|
||||
...
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self, conversation_repo: ConversationRepository):
|
||||
self.conversation_repo = conversation_repo
|
||||
```
|
||||
@@ -1,139 +0,0 @@
|
||||
# Rule Catalog — SQLAlchemy Patterns
|
||||
|
||||
## Scope
|
||||
- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards.
|
||||
- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`).
|
||||
|
||||
## Rules
|
||||
|
||||
### Use Session context manager with explicit transaction control behavior
|
||||
- Category: best practices
|
||||
- Severity: critical
|
||||
- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk.
|
||||
- Suggested fix:
|
||||
- Use **explicit `session.commit()`** after completing a related write unit.
|
||||
- Or use **`session.begin()` context manager** for automatic commit/rollback on a scoped block.
|
||||
- Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# Missing commit: write may never be persisted.
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
|
||||
# Long transaction: external I/O inside a DB transaction.
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
call_external_api()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# Option 1: explicit commit.
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
session.commit()
|
||||
|
||||
# Option 2: scoped transaction with automatic commit/rollback.
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
|
||||
# Keep non-DB work outside transaction scope.
|
||||
call_external_api()
|
||||
```
|
||||
|
||||
### Enforce tenant_id scoping on shared-resource queries
|
||||
- Category: security
|
||||
- Severity: critical
|
||||
- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption.
|
||||
- Suggested fix: Add `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
### Prefer SQLAlchemy expressions over raw SQL by default
|
||||
- Category: maintainability
|
||||
- Severity: suggestion
|
||||
- Description: Raw SQL should be exceptional. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase.
|
||||
- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select/update/delete` expressions; keep raw SQL only when required by clear technical constraints.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
row = session.execute(
|
||||
text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"),
|
||||
{"id": workflow_id, "tenant_id": tenant_id},
|
||||
).first()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
row = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
### Protect write paths with concurrency safeguards
|
||||
- Category: quality
|
||||
- Severity: critical
|
||||
- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy.
|
||||
- Suggested fix:
|
||||
- **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or updated_at) guard in `WHERE` and treat `rowcount == 0` as a conflict.
|
||||
- **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion.
|
||||
- **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk.
|
||||
- In all cases, scope by `tenant_id` and verify affected row counts for conditional writes.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# No tenant scope, no conflict detection, and no lock on a contested write path.
|
||||
session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled"))
|
||||
session.commit() # silently overwrites concurrent updates
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# 1) Optimistic lock (low contention, retry on conflict)
|
||||
result = session.execute(
|
||||
update(WorkflowRun)
|
||||
.where(
|
||||
WorkflowRun.id == run_id,
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.version == expected_version,
|
||||
)
|
||||
.values(status="cancelled", version=WorkflowRun.version + 1)
|
||||
)
|
||||
if result.rowcount == 0:
|
||||
raise WorkflowStateConflictError("stale version, retry")
|
||||
|
||||
# 2) Redis distributed lock (cross-worker critical section)
|
||||
lock_name = f"workflow_run_lock:{tenant_id}:{run_id}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
session.execute(
|
||||
update(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
|
||||
.values(status="cancelled")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention)
|
||||
run = session.execute(
|
||||
select(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
run.status = "cancelled"
|
||||
session.commit()
|
||||
```
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/backend-code-review
|
||||
23
.github/dependabot.yml
vendored
23
.github/dependabot.yml
vendored
@@ -1,25 +1,12 @@
|
||||
version: 2
|
||||
|
||||
multi-ecosystem-groups:
|
||||
python:
|
||||
schedule:
|
||||
interval: "weekly" # or whatever schedule you want
|
||||
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
patterns: ["*"]
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
patterns: ["*"]
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/web"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
|
||||
88
.github/workflows/pyrefly-diff-comment.yml
vendored
88
.github/workflows/pyrefly-diff-comment.yml
vendored
@@ -1,88 +0,0 @@
|
||||
name: Comment with Pyrefly Diff
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows:
|
||||
- Pyrefly Diff Check
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with pyrefly diff
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Download pyrefly diff artifact
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }},
|
||||
});
|
||||
const match = artifacts.data.artifacts.find((artifact) =>
|
||||
artifact.name === 'pyrefly_diff'
|
||||
);
|
||||
if (!match) {
|
||||
throw new Error('pyrefly_diff artifact not found');
|
||||
}
|
||||
const download = await github.rest.actions.downloadArtifact({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
artifact_id: match.id,
|
||||
archive_format: 'zip',
|
||||
});
|
||||
fs.writeFileSync('pyrefly_diff.zip', Buffer.from(download.data));
|
||||
|
||||
- name: Unzip artifact
|
||||
run: unzip -o pyrefly_diff.zip
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
|
||||
let prNumber = null;
|
||||
try {
|
||||
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
|
||||
} catch (err) {
|
||||
// Fallback to workflow_run payload if artifact is missing or incomplete.
|
||||
const prs = context.payload.workflow_run.pull_requests || [];
|
||||
if (prs.length > 0 && prs[0].number) {
|
||||
prNumber = prs[0].number;
|
||||
}
|
||||
}
|
||||
if (!prNumber) {
|
||||
throw new Error('PR number not found in artifact or workflow_run payload');
|
||||
}
|
||||
|
||||
const MAX_CHARS = 65000;
|
||||
if (diff.length > MAX_CHARS) {
|
||||
diff = diff.slice(0, MAX_CHARS);
|
||||
diff = diff.slice(0, diff.lastIndexOf('\\n'));
|
||||
diff += '\\n\\n... (truncated) ...';
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
94
.github/workflows/pyrefly-diff.yml
vendored
94
.github/workflows/pyrefly-diff.yml
vendored
@@ -1,94 +0,0 @@
|
||||
name: Pyrefly Diff Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'api/**/*.py'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pyrefly-diff:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run pyrefly on PR branch
|
||||
run: |
|
||||
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
|
||||
|
||||
- name: Checkout base branch
|
||||
run: git checkout ${{ github.base_ref }}
|
||||
|
||||
- name: Run pyrefly on base branch
|
||||
run: |
|
||||
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
|
||||
|
||||
- name: Compute diff
|
||||
run: |
|
||||
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload pyrefly diff
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pyrefly_diff
|
||||
path: |
|
||||
pyrefly_diff.txt
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
|
||||
const MAX_CHARS = 65000;
|
||||
if (diff.length > MAX_CHARS) {
|
||||
diff = diff.slice(0, MAX_CHARS);
|
||||
diff = diff.slice(0, diff.lastIndexOf('\n'));
|
||||
diff += '\n\n... (truncated) ...';
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? [
|
||||
'### Pyrefly Diff',
|
||||
'<details>',
|
||||
'<summary>base → PR</summary>',
|
||||
'',
|
||||
'```diff',
|
||||
diff,
|
||||
'```',
|
||||
'</details>',
|
||||
].join('\n')
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
63
.github/workflows/web-tests.yml
vendored
63
.github/workflows/web-tests.yml
vendored
@@ -3,22 +3,14 @@ name: Web Tests
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
name: Web Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -47,58 +39,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
include-hidden-files: true
|
||||
retention-days: 1
|
||||
|
||||
merge-reports:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
run: pnpm test:ci
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
|
||||
|
||||
@@ -553,8 +553,6 @@ WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
# Batch size for workflow log cleanup operations (default: 100)
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
|
||||
# Comma-separated list of workflow IDs to clean logs for
|
||||
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS=
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
|
||||
@@ -50,11 +50,14 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
@@ -88,6 +91,7 @@ forbidden_modules =
|
||||
core.logging
|
||||
core.mcp
|
||||
core.memory
|
||||
core.model_manager
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
@@ -101,18 +105,29 @@ forbidden_modules =
|
||||
core.variables
|
||||
ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
core.workflow.nodes.document_extractor.node -> configs
|
||||
core.workflow.nodes.document_extractor.node -> core.file.file_manager
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.entities -> configs
|
||||
core.workflow.nodes.http_request.executor -> configs
|
||||
core.workflow.nodes.http_request.executor -> core.file.file_manager
|
||||
core.workflow.nodes.http_request.node -> configs
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.file.models
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
@@ -142,18 +157,47 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.node_events.node -> core.file
|
||||
core.workflow.nodes.agent.agent_node -> core.file
|
||||
core.workflow.nodes.datasource.datasource_node -> core.file
|
||||
core.workflow.nodes.datasource.datasource_node -> core.file.enums
|
||||
core.workflow.nodes.document_extractor.node -> core.file
|
||||
core.workflow.nodes.http_request.executor -> core.file.enums
|
||||
core.workflow.nodes.http_request.node -> core.file
|
||||
core.workflow.nodes.http_request.node -> core.file.file_manager
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
|
||||
core.workflow.nodes.list_operator.node -> core.file
|
||||
core.workflow.nodes.llm.file_saver -> core.file
|
||||
core.workflow.nodes.llm.llm_utils -> core.variables.segments
|
||||
core.workflow.nodes.llm.node -> core.file
|
||||
core.workflow.nodes.llm.node -> core.file.file_manager
|
||||
core.workflow.nodes.llm.node -> core.file.models
|
||||
core.workflow.nodes.loop.entities -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
|
||||
core.workflow.nodes.protocols -> core.file
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
|
||||
core.workflow.nodes.tool.tool_node -> core.file
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.tool.tool_node -> models
|
||||
core.workflow.nodes.trigger_webhook.node -> core.file
|
||||
core.workflow.runtime.variable_pool -> core.file
|
||||
core.workflow.runtime.variable_pool -> core.file.file_manager
|
||||
core.workflow.system_variable -> core.file.models
|
||||
core.workflow.utils.condition.processor -> core.file
|
||||
core.workflow.utils.condition.processor -> core.file.file_manager
|
||||
core.workflow.workflow_entry -> core.file.models
|
||||
core.workflow.workflow_type_encoder -> core.file.models
|
||||
core.workflow.nodes.agent.agent_node -> models.model
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
|
||||
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
|
||||
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
@@ -190,6 +234,7 @@ ignore_imports =
|
||||
core.workflow.nodes.code.code_node -> core.variables.segments
|
||||
core.workflow.nodes.code.code_node -> core.variables.types
|
||||
core.workflow.nodes.code.entities -> core.variables.types
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
|
||||
core.workflow.nodes.document_extractor.node -> core.variables
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
@@ -231,7 +276,9 @@ ignore_imports =
|
||||
core.workflow.variable_loader -> core.variables
|
||||
core.workflow.variable_loader -> core.variables.consts
|
||||
core.workflow.workflow_type_encoder -> core.variables
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
@@ -247,11 +294,6 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
core.workflow.nodes.tool.tool_node -> services
|
||||
core.workflow.nodes.agent.agent_node -> core.model_runtime.token_buffer_memory
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.token_buffer_memory
|
||||
core.workflow.nodes.llm.node -> core.model_runtime.token_buffer_memory
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.token_buffer_memory
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.token_buffer_memory
|
||||
|
||||
[importlinter:contract:model-runtime-no-internal-imports]
|
||||
name = Model Runtime Internal Imports
|
||||
@@ -304,13 +346,6 @@ ignore_imports =
|
||||
core.model_runtime.model_providers.model_provider_factory -> configs
|
||||
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
|
||||
core.model_runtime.token_buffer_memory -> core.app.app_config.features.file_upload.manager
|
||||
core.model_runtime.token_buffer_memory -> core.model_manager
|
||||
core.model_runtime.token_buffer_memory -> core.prompt.utils.extract_thread_messages
|
||||
core.model_runtime.token_buffer_memory -> core.workflow.file.file_manager
|
||||
core.model_runtime.token_buffer_memory -> extensions.ext_database
|
||||
core.model_runtime.token_buffer_memory -> models.model
|
||||
core.model_runtime.token_buffer_memory -> models.workflow
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
|
||||
@@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
1. Optional: start the worker service (async tasks, runs from `api`).
|
||||
|
||||
```bash
|
||||
./dev/start-worker
|
||||
@@ -54,6 +54,86 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-beat
|
||||
```
|
||||
|
||||
### Manual commands
|
||||
|
||||
<details>
|
||||
<summary>Show manual setup and run steps</summary>
|
||||
|
||||
These commands assume you start from the repository root.
|
||||
|
||||
1. Start the docker-compose stack.
|
||||
|
||||
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
```bash
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
# Use mysql or another vector database profile if you are not using postgres/weaviate.
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
```
|
||||
|
||||
1. Copy env files.
|
||||
|
||||
```bash
|
||||
cp api/.env.example api/.env
|
||||
cp web/.env.example web/.env.local
|
||||
```
|
||||
|
||||
1. Install UV if needed.
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
# Or on macOS
|
||||
brew install uv
|
||||
```
|
||||
|
||||
1. Install API dependencies.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv sync --group dev
|
||||
```
|
||||
|
||||
1. Install web dependencies.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm install
|
||||
cd ..
|
||||
```
|
||||
|
||||
1. Start backend (runs migrations first, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run flask db upgrade
|
||||
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
|
||||
1. Start Dify [web](../web) service (in a new terminal).
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm dev:inspect
|
||||
```
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Optional: start the worker service (async tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Environment notes
|
||||
|
||||
> [!IMPORTANT]
|
||||
|
||||
@@ -30,7 +30,6 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@@ -55,8 +54,6 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@@ -730,15 +727,8 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
@click.command("upgrade-db", help="Upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo("Preparing database migration...")
|
||||
lock = DbMigrationAutoRenewLock(
|
||||
redis_client=redis_client,
|
||||
name="db_upgrade_lock",
|
||||
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||
logger=logger,
|
||||
log_context="db_migration",
|
||||
)
|
||||
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
|
||||
if lock.acquire(blocking=False):
|
||||
migration_succeeded = False
|
||||
try:
|
||||
click.echo(click.style("Starting database migration.", fg="green"))
|
||||
|
||||
@@ -747,7 +737,6 @@ def upgrade_db():
|
||||
|
||||
flask_migrate.upgrade()
|
||||
|
||||
migration_succeeded = True
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
@@ -755,8 +744,7 @@ def upgrade_db():
|
||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
lock.release_safely(status=status)
|
||||
lock.release()
|
||||
else:
|
||||
click.echo("Database migration skipped")
|
||||
|
||||
|
||||
@@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size (bytes) for plugin-generated files",
|
||||
default=50 * 1024 * 1024,
|
||||
)
|
||||
|
||||
|
||||
class MarketplaceConfig(BaseSettings):
|
||||
"""
|
||||
@@ -1319,9 +1314,6 @@ class WorkflowLogConfig(BaseSettings):
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
|
||||
default=100, description="Batch size for workflow run log cleanup operations"
|
||||
)
|
||||
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS: str = Field(
|
||||
default="", description="Comma-separated list of workflow IDs to clean logs for"
|
||||
)
|
||||
|
||||
|
||||
class SwaggerUIConfig(BaseSettings):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@@ -51,43 +49,3 @@ class OceanBaseVectorConfig(BaseSettings):
|
||||
),
|
||||
default="ik",
|
||||
)
|
||||
|
||||
OCEANBASE_VECTOR_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Number of documents to insert per batch",
|
||||
default=100,
|
||||
)
|
||||
|
||||
OCEANBASE_VECTOR_METRIC_TYPE: Literal["l2", "cosine", "inner_product"] = Field(
|
||||
description="Distance metric type for vector index: l2, cosine, or inner_product",
|
||||
default="l2",
|
||||
)
|
||||
|
||||
OCEANBASE_HNSW_M: PositiveInt = Field(
|
||||
description="HNSW M parameter (max number of connections per node)",
|
||||
default=16,
|
||||
)
|
||||
|
||||
OCEANBASE_HNSW_EF_CONSTRUCTION: PositiveInt = Field(
|
||||
description="HNSW efConstruction parameter (index build-time search width)",
|
||||
default=256,
|
||||
)
|
||||
|
||||
OCEANBASE_HNSW_EF_SEARCH: int = Field(
|
||||
description="HNSW efSearch parameter (query-time search width, -1 uses server default)",
|
||||
default=-1,
|
||||
)
|
||||
|
||||
OCEANBASE_VECTOR_POOL_SIZE: PositiveInt = Field(
|
||||
description="SQLAlchemy connection pool size",
|
||||
default=5,
|
||||
)
|
||||
|
||||
OCEANBASE_VECTOR_MAX_OVERFLOW: int = Field(
|
||||
description="SQLAlchemy connection pool max overflow connections",
|
||||
default=10,
|
||||
)
|
||||
|
||||
OCEANBASE_HNSW_REFRESH_THRESHOLD: int = Field(
|
||||
description="Minimum number of inserted documents to trigger an automatic HNSW index refresh (0 to disable)",
|
||||
default=1000,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from core.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
|
||||
@@ -23,10 +23,10 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
@@ -660,19 +660,6 @@ class AppCopyApi(Resource):
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
try:
|
||||
# Get the original app's access mode
|
||||
original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
|
||||
access_mode = original_settings.access_mode
|
||||
except Exception:
|
||||
# If original app has no settings (old app), default to public to match fallback behavior
|
||||
access_mode = "public"
|
||||
|
||||
# Apply the same access mode to the copied app
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
|
||||
|
||||
stmt = select(App).where(App.id == result.app_id)
|
||||
app = session.scalar(stmt)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
@@ -30,10 +31,8 @@ from core.trigger.debug.event_selectors import (
|
||||
select_trigger_debug_events,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
@@ -741,7 +740,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
@@ -112,11 +112,11 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = {
|
||||
**_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
"value": fields.Raw(attribute=_serialize_var_value),
|
||||
"full_content": fields.Raw(attribute=_serialize_full_content),
|
||||
}
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
full_content=fields.Raw(attribute=_serialize_full_content),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
|
||||
@@ -10,7 +10,7 @@ import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@@ -44,7 +44,6 @@ from core.errors.error import (
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.app_fields import (
|
||||
app_detail_fields_with_site,
|
||||
deleted_tool_fields,
|
||||
@@ -226,7 +225,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -470,7 +469,7 @@ class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
@@ -492,7 +491,7 @@ class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
@@ -521,7 +520,7 @@ class TrialAppParameterApi(Resource):
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
@@ -534,7 +533,7 @@ class AppApi(Resource):
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, app_model):
|
||||
"""Get workflow detail"""
|
||||
@@ -553,7 +552,7 @@ class AppWorkflowApi(Resource):
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial(None)
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
@@ -571,31 +570,27 @@ class DatasetListApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
console_ns.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
console_ns.add_resource(
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
console_ns.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
console_ns.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
console_ns.add_resource(
|
||||
TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion"
|
||||
)
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
console_ns.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
console_ns.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
console_ns.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
console_ns.add_resource(
|
||||
TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run"
|
||||
)
|
||||
console_ns.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
|
||||
console_ns.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
console_ns.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
|
||||
@@ -23,7 +23,6 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode, InstalledApp
|
||||
@@ -101,6 +100,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -105,9 +105,9 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[P, R]):
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
@@ -116,9 +116,9 @@ def trial_feature_enable(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[P, R]):
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
|
||||
@@ -12,8 +12,8 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
||||
@@ -137,7 +137,7 @@ class FilePreviewApi(Resource):
|
||||
if args.as_attachment:
|
||||
encoded_filename = quote(upload_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
|
||||
enforce_download_for_html(
|
||||
response,
|
||||
|
||||
@@ -64,10 +64,6 @@ class ToolFileApi(Resource):
|
||||
|
||||
if not stream or not tool_file:
|
||||
raise NotFound("file is not found")
|
||||
|
||||
except NotFound:
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from core.file.helpers import verify_plugin_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file.helpers import verify_plugin_file_signature
|
||||
from fields.file_fields import FileResponse
|
||||
|
||||
from ..common.errors import (
|
||||
|
||||
@@ -4,6 +4,7 @@ from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.file.helpers import get_signed_file_url_for_plugin
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.file.helpers import get_signed_file_url_for_plugin
|
||||
from libs.helper import length_prefixed_response
|
||||
from models import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
@@ -31,7 +31,6 @@ from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
@@ -281,7 +280,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -3,8 +3,7 @@ from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
@@ -18,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Pipeline
|
||||
from models.dataset import Pipeline
|
||||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
@@ -66,12 +65,6 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Get query parameter to determine published or draft
|
||||
is_published: bool = request.args.get("is_published", default=True, type=bool)
|
||||
|
||||
@@ -111,12 +104,6 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
@@ -174,12 +161,6 @@ class PipelineRunApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
||||
@@ -10,8 +10,8 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from services.file_service import FileService
|
||||
|
||||
@@ -24,7 +24,6 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@@ -122,6 +121,6 @@ class WorkflowTaskStopApi(WebApiResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -17,6 +17,8 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file import file_manager
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
@@ -31,7 +33,6 @@ from core.model_runtime.entities import (
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
@@ -39,7 +40,6 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.workflow.file import file_manager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.enums import CreatorUserRole
|
||||
@@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
# check if model supports stream tool call
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
|
||||
@@ -245,7 +245,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model_name,
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
@@ -268,7 +268,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model_name,
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@@ -10,7 +11,6 @@ from core.model_runtime.entities import (
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.file import file_manager
|
||||
|
||||
|
||||
class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Union
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
@@ -24,7 +25,6 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.file import file_manager
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
@@ -178,7 +178,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model_name,
|
||||
model=model_instance.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
system_fingerprint=result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
@@ -308,7 +308,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model_name,
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Any, Literal
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.workflow.file import FileUploadConfig
|
||||
from core.file import FileUploadConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
|
||||
@@ -669,14 +669,16 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle retriever resources events."""
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
yield from ()
|
||||
return
|
||||
yield # Make this a generator
|
||||
|
||||
def _handle_annotation_reply_event(
|
||||
self, event: QueueAnnotationReplyEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle annotation reply events."""
|
||||
self._message_cycle_manager.handle_annotation_reply(event)
|
||||
yield from ()
|
||||
return
|
||||
yield # Make this a generator
|
||||
|
||||
def _handle_message_replace_event(
|
||||
self, event: QueueMessageReplaceEvent, **kwargs
|
||||
|
||||
@@ -12,11 +12,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
@@ -178,7 +178,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
|
||||
# change function call strategy based on LLM model
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
@@ -31,7 +31,7 @@ class PublishFrom(IntEnum):
|
||||
TASK_PIPELINE = auto()
|
||||
|
||||
|
||||
class AppQueueManager(ABC):
|
||||
class AppQueueManager:
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
@@ -122,7 +122,7 @@ class AppQueueManager(ABC):
|
||||
"""Attach the live graph runtime state reference for downstream consumers."""
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
@@ -133,7 +133,7 @@ class AppQueueManager(ABC):
|
||||
self._publish(event, pub_from)
|
||||
|
||||
@abstractmethod
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
|
||||
@@ -22,6 +22,8 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
@@ -32,19 +34,17 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.file.models import File
|
||||
from core.file.models import File
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -11,12 +11,12 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file import File
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from core.app.entities.task_entities import (
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@@ -59,7 +60,6 @@ from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
@@ -10,11 +10,11 @@ from core.app.entities.app_invoke_entities import (
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file import File
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Message
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""LLM-related application services."""
|
||||
@@ -1,103 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
|
||||
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_manager = provider_manager or ProviderManager()
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider_name} does not exist.")
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
class DifyModelFactory:
|
||||
tenant_id: str
|
||||
model_manager: ModelManager
|
||||
|
||||
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.model_manager = model_manager or ModelManager()
|
||||
|
||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||
return self.model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
|
||||
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
|
||||
return (
|
||||
DifyCredentialsProvider(tenant_id=tenant_id),
|
||||
DifyModelFactory(tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
def fetch_model_config(
|
||||
*,
|
||||
node_data_model: ModelConfig,
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
|
||||
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
@@ -45,6 +45,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
@@ -57,8 +59,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from core.workflow.file.enums import FileTransferMethod
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -157,7 +157,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@@ -170,7 +170,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@@ -283,7 +283,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
self._task_state.llm_result.message.get_text_content()
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
@@ -397,7 +397,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from core.workflow.file.runtime import set_workflow_file_runtime
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
"""Production runtime wiring for ``core.workflow.file``."""
|
||||
|
||||
@property
|
||||
def files_url(self) -> str:
|
||||
return dify_config.FILES_URL
|
||||
|
||||
@property
|
||||
def internal_files_url(self) -> str | None:
|
||||
return dify_config.INTERNAL_FILES_URL
|
||||
|
||||
@property
|
||||
def secret_key(self) -> str:
|
||||
return dify_config.SECRET_KEY
|
||||
|
||||
@property
|
||||
def files_access_timeout(self) -> int:
|
||||
return dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@property
|
||||
def multimodal_send_format(self) -> str:
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
|
||||
|
||||
|
||||
def bind_dify_workflow_file_runtime() -> None:
|
||||
set_workflow_file_runtime(DifyWorkflowFileRuntime())
|
||||
@@ -1,34 +1,28 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
from core.file.file_manager import file_manager
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from core.workflow.nodes.code.entities import CodeLanguage
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
|
||||
@@ -37,24 +31,6 @@ if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class DefaultWorkflowCodeExecutor:
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
language: CodeLanguage,
|
||||
code: str,
|
||||
inputs: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
return CodeExecutor.execute_workflow_code_template(
|
||||
language=language,
|
||||
code=code,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
def is_execution_error(self, error: Exception) -> bool:
|
||||
return isinstance(error, CodeExecutionError)
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
@@ -68,12 +44,23 @@ class DifyNodeFactory(NodeFactory):
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
template_transform_max_output_length: int | None = None,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = CodeNode.default_code_providers()
|
||||
self._code_limits = CodeNodeLimits(
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
|
||||
)
|
||||
self._code_limits = code_limits or CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
@@ -83,27 +70,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = ToolFileManager
|
||||
self._http_request_file_manager = file_manager
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = (
|
||||
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
)
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
|
||||
)
|
||||
self._http_request_config = build_http_request_config(
|
||||
max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
max_write_timeout=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
max_binary_size=dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE,
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
@@ -164,31 +138,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_request_config=self._http_request_config,
|
||||
http_client=self._http_request_http_client,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
@@ -198,35 +152,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DOCUMENT_EXTRACTOR:
|
||||
return DocumentExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@@ -213,6 +213,6 @@ class DatasourceFileManager:
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
# from core.workflow.file.datasource_file_parser import datasource_file_manager
|
||||
# from core.file.datasource_file_parser import datasource_file_manager
|
||||
#
|
||||
# datasource_file_manager["manager"] = DatasourceFileManager
|
||||
|
||||
@@ -1,39 +1,16 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from threading import Lock
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.file import File
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,238 +103,3 @@ class DatasourceManager:
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
|
||||
@classmethod
|
||||
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
|
||||
datasource_runtime = cls.get_datasource_runtime(
|
||||
provider_id=provider_id,
|
||||
datasource_name=datasource_name,
|
||||
tenant_id=tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
return datasource_runtime.get_icon_url(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def stream_online_results(
|
||||
cls,
|
||||
*,
|
||||
user_id: str,
|
||||
datasource_name: str,
|
||||
datasource_type: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
credential_id: str,
|
||||
datasource_param: DatasourceParameter | None = None,
|
||||
online_drive_request: OnlineDriveDownloadFileParam | None = None,
|
||||
) -> Generator[DatasourceMessage, None, Any]:
|
||||
"""
|
||||
Pull-based streaming of domain messages from datasource plugins.
|
||||
Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
|
||||
Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
|
||||
"""
|
||||
ds_type = DatasourceProviderType.value_of(datasource_type)
|
||||
runtime = cls.get_datasource_runtime(
|
||||
provider_id=provider_id,
|
||||
datasource_name=datasource_name,
|
||||
tenant_id=tenant_id,
|
||||
datasource_type=ds_type,
|
||||
)
|
||||
|
||||
dsp_service = DatasourceProviderService()
|
||||
credentials = dsp_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
|
||||
if credentials:
|
||||
doc_runtime.runtime.credentials = credentials
|
||||
if datasource_param is None:
|
||||
raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
|
||||
inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
|
||||
user_id=user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_param.workspace_id,
|
||||
page_id=datasource_param.page_id,
|
||||
type=datasource_param.type,
|
||||
),
|
||||
provider_type=ds_type,
|
||||
)
|
||||
elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
|
||||
if credentials:
|
||||
drive_runtime.runtime.credentials = credentials
|
||||
if online_drive_request is None:
|
||||
raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
|
||||
inner_gen = drive_runtime.online_drive_download_file(
|
||||
user_id=user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=online_drive_request.id,
|
||||
bucket=online_drive_request.bucket,
|
||||
),
|
||||
provider_type=ds_type,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
|
||||
|
||||
# Bridge through to caller while preserving generator return contract
|
||||
yield from inner_gen
|
||||
# No structured final data here; node/adapter will assemble outputs
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def stream_node_events(
|
||||
cls,
|
||||
*,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
datasource_name: str,
|
||||
datasource_type: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
credential_id: str,
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: Any,
|
||||
datasource_param: DatasourceParameter | None = None,
|
||||
online_drive_request: OnlineDriveDownloadFileParam | None = None,
|
||||
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
|
||||
ds_type = DatasourceProviderType.value_of(datasource_type)
|
||||
|
||||
messages = cls.stream_online_results(
|
||||
user_id=user_id,
|
||||
datasource_name=datasource_name,
|
||||
datasource_type=datasource_type,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
credential_id=credential_id,
|
||||
datasource_param=datasource_param,
|
||||
online_drive_request=online_drive_request,
|
||||
)
|
||||
|
||||
transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
|
||||
)
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
file_out: File | None = None
|
||||
|
||||
for message in transformed:
|
||||
mtype = message.type
|
||||
if mtype in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
wanted_ds_type = ds_type in {
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
}
|
||||
if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
|
||||
url = message.message.text
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ToolFile).where(
|
||||
ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
|
||||
)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if not datasource_file:
|
||||
raise ValueError(
|
||||
f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
|
||||
)
|
||||
mime_type = datasource_file.mimetype
|
||||
if datasource_file is not None:
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(mime_type),
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
"url": url,
|
||||
}
|
||||
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
|
||||
elif mtype == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
|
||||
elif mtype == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
|
||||
)
|
||||
elif mtype == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
name = message.message.variable_name
|
||||
value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
assert isinstance(value, str), "stream variable_value must be str"
|
||||
variables[name] = variables.get(name, "") + value
|
||||
yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
|
||||
else:
|
||||
variables[name] = value
|
||||
elif mtype == DatasourceMessage.MessageType.FILE:
|
||||
if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
|
||||
f = message.meta.get("file")
|
||||
if isinstance(f, File):
|
||||
file_out = f
|
||||
else:
|
||||
pass
|
||||
|
||||
yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
|
||||
|
||||
if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
|
||||
variable_pool.add([node_id, "file"], file_out)
|
||||
|
||||
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={**variables},
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file_out,
|
||||
"datasource_type": ds_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = (
|
||||
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
)
|
||||
return file_info
|
||||
|
||||
@@ -379,11 +379,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the file")
|
||||
bucket: str = Field("", description="The name of the bucket")
|
||||
|
||||
@field_validator("bucket", mode="before")
|
||||
@classmethod
|
||||
def _coerce_bucket(cls, v) -> str:
|
||||
if v is None:
|
||||
return ""
|
||||
return str(v)
|
||||
bucket: str | None = Field(None, description="The name of the bucket")
|
||||
|
||||
@@ -3,8 +3,8 @@ from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file import File, FileTransferMethod, FileType
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,12 +10,12 @@ from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from collections.abc import Mapping
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.entities import (
|
||||
AudioPromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
@@ -11,11 +11,12 @@ from core.model_runtime.entities import (
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from . import helpers
|
||||
from .enums import FileAttribute
|
||||
from .models import File, FileTransferMethod, FileType
|
||||
from .runtime import get_workflow_file_runtime
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
|
||||
@@ -44,7 +45,26 @@ def to_prompt_message_content(
|
||||
*,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> PromptMessageContentUnionTypes:
|
||||
"""Convert a file to prompt message content."""
|
||||
"""
|
||||
Convert a file to prompt message content.
|
||||
|
||||
This function converts files to their appropriate prompt message content types.
|
||||
For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
|
||||
corresponding message content with proper encoding/URL.
|
||||
|
||||
For unsupported file types, instead of raising an error, it returns a
|
||||
TextPromptMessageContent with a descriptive message about the file.
|
||||
|
||||
Args:
|
||||
f: The file to convert
|
||||
image_detail_config: Optional detail configuration for image files
|
||||
|
||||
Returns:
|
||||
PromptMessageContentUnionTypes: The appropriate message content type
|
||||
|
||||
Raises:
|
||||
ValueError: If file extension or mime_type is missing
|
||||
"""
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
if f.mime_type is None:
|
||||
@@ -57,13 +77,15 @@ def to_prompt_message_content(
|
||||
FileType.DOCUMENT: DocumentPromptMessageContent,
|
||||
}
|
||||
|
||||
# Check if file type is supported
|
||||
if f.type not in prompt_class_map:
|
||||
# For unsupported file types, return a text description
|
||||
return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
|
||||
|
||||
send_format = get_workflow_file_runtime().multimodal_send_format
|
||||
# Process supported file types
|
||||
params = {
|
||||
"base64_data": _get_encoded_string(f) if send_format == "base64" else "",
|
||||
"url": _to_url(f) if send_format == "url" else "",
|
||||
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
|
||||
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
|
||||
"format": f.extension.removeprefix("."),
|
||||
"mime_type": f.mime_type,
|
||||
"filename": f.filename or "",
|
||||
@@ -74,7 +96,7 @@ def to_prompt_message_content(
|
||||
return prompt_class_map[f.type].model_validate(params)
|
||||
|
||||
|
||||
def download(f: File, /) -> bytes:
|
||||
def download(f: File, /):
|
||||
if f.transfer_method in (
|
||||
FileTransferMethod.TOOL_FILE,
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
@@ -84,26 +106,39 @@ def download(f: File, /) -> bytes:
|
||||
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
raise ValueError("Missing file remote_url")
|
||||
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def _download_file_content(path: str, /) -> bytes:
|
||||
"""Download and return a file from storage as bytes."""
|
||||
data = get_workflow_file_runtime().storage_load(path, stream=False)
|
||||
def _download_file_content(path: str, /):
|
||||
"""
|
||||
Download and return the contents of a file as bytes.
|
||||
|
||||
This function loads the file from storage and ensures it's in bytes format.
|
||||
|
||||
Args:
|
||||
path (str): The path to the file in storage.
|
||||
|
||||
Returns:
|
||||
bytes: The contents of the file as a bytes object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the loaded file is not a bytes object.
|
||||
"""
|
||||
data = storage.load(path, stream=False)
|
||||
if not isinstance(data, bytes):
|
||||
raise ValueError(f"file {path} is not a bytes object")
|
||||
return data
|
||||
|
||||
|
||||
def _get_encoded_string(f: File, /) -> str:
|
||||
def _get_encoded_string(f: File, /):
|
||||
match f.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
raise ValueError("Missing file remote_url")
|
||||
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
data = response.content
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
@@ -113,7 +148,8 @@ def _get_encoded_string(f: File, /) -> str:
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
data = _download_file_content(f.storage_key)
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
|
||||
|
||||
def _to_url(f: File, /):
|
||||
@@ -126,15 +162,21 @@ def _to_url(f: File, /):
|
||||
raise ValueError("Missing file related_id")
|
||||
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
|
||||
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
# add sign url
|
||||
if f.related_id is None or f.extension is None:
|
||||
raise ValueError("Missing file related_id or extension")
|
||||
return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension)
|
||||
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
class FileManager:
|
||||
"""Adapter exposing file manager helpers behind FileManagerProtocol."""
|
||||
"""
|
||||
Adapter exposing file manager helpers behind FileManagerProtocol.
|
||||
|
||||
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
|
||||
where a protocol-typed file manager is expected.
|
||||
"""
|
||||
|
||||
def download(self, f: File, /) -> bytes:
|
||||
return download(f)
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
@@ -7,21 +5,20 @@ import os
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
from .runtime import get_workflow_file_runtime
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
|
||||
runtime = get_workflow_file_runtime()
|
||||
base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url)
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
key = runtime.secret_key.encode()
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign}
|
||||
query = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign}
|
||||
if as_attachment:
|
||||
query["as_attachment"] = "true"
|
||||
query_string = urllib.parse.urlencode(query)
|
||||
@@ -30,63 +27,57 @@ def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_ex
|
||||
|
||||
|
||||
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
|
||||
runtime = get_workflow_file_runtime()
|
||||
# Plugin access should use internal URL for Docker network communication.
|
||||
base_url = runtime.internal_files_url or runtime.files_url
|
||||
# Plugin access should use internal URL for Docker network communication
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
url = f"{base_url}/files/upload/for-plugin"
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
key = runtime.secret_key.encode()
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
|
||||
|
||||
|
||||
def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
runtime = get_workflow_file_runtime()
|
||||
return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
|
||||
|
||||
|
||||
def verify_plugin_file_signature(
|
||||
*, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
|
||||
) -> bool:
|
||||
runtime = get_workflow_file_runtime()
|
||||
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
secret_key = runtime.secret_key.encode()
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= runtime.files_access_timeout
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
|
||||
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
runtime = get_workflow_file_runtime()
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = runtime.secret_key.encode()
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= runtime.files_access_timeout
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
|
||||
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
runtime = get_workflow_file_runtime()
|
||||
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = runtime.secret_key.encode()
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= runtime.files_access_timeout
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@@ -1,26 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.tools.signature import sign_tool_file
|
||||
|
||||
from . import helpers
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import FileTransferMethod, FileType
|
||||
|
||||
|
||||
def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
"""Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``."""
|
||||
return helpers.get_signed_tool_file_url(
|
||||
tool_file_id=tool_file_id,
|
||||
extension=extension,
|
||||
for_external=for_external,
|
||||
)
|
||||
|
||||
|
||||
class ImageConfig(BaseModel):
|
||||
"""
|
||||
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||
@@ -132,11 +122,7 @@ class File(BaseModel):
|
||||
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return sign_tool_file(
|
||||
tool_file_id=self.related_id,
|
||||
extension=self.extension,
|
||||
for_external=for_external,
|
||||
)
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
|
||||
return None
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
@@ -151,7 +137,7 @@ class File(BaseModel):
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_after(self) -> File:
|
||||
def validate_after(self):
|
||||
match self.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
if not self.remote_url:
|
||||
@@ -174,5 +160,5 @@ class File(BaseModel):
|
||||
return self._storage_key
|
||||
|
||||
@storage_key.setter
|
||||
def storage_key(self, value: str) -> None:
|
||||
def storage_key(self, value: str):
|
||||
self._storage_key = value
|
||||
12
api/core/file/tool_file_parser.py
Normal file
12
api/core/file/tool_file_parser.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||
|
||||
|
||||
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]):
|
||||
global _tool_file_manager_factory
|
||||
_tool_file_manager_factory = factory
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
@@ -15,7 +16,6 @@ from core.model_runtime.entities import (
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.workflow.file import file_manager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
@@ -35,7 +35,7 @@ class ModelInstance:
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model_name = model
|
||||
self.model = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.model_type_instance = self.provider_model_bundle.model_type_instance
|
||||
@@ -163,7 +163,7 @@ class ModelInstance:
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
@@ -191,7 +191,7 @@ class ModelInstance:
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
@@ -215,7 +215,7 @@ class ModelInstance:
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
@@ -243,7 +243,7 @@ class ModelInstance:
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
user=user,
|
||||
@@ -264,7 +264,7 @@ class ModelInstance:
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
@@ -294,7 +294,7 @@ class ModelInstance:
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
@@ -328,7 +328,7 @@ class ModelInstance:
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
@@ -352,7 +352,7 @@ class ModelInstance:
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user,
|
||||
@@ -373,7 +373,7 @@ class ModelInstance:
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user,
|
||||
@@ -396,7 +396,7 @@ class ModelInstance:
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
@@ -469,7 +469,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self.model_type_instance.get_tts_model_voices(
|
||||
model=self.model_name, credentials=self.credentials, language=language
|
||||
model=self.model, credentials=self.credentials, language=language
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class Moderation(Extensible, ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
DIFY_APP_ID,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
@@ -100,16 +99,6 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Aliyun get project url failed: {str(e)}")
|
||||
|
||||
def _extract_app_id(self, trace_info: BaseTraceInfo) -> str:
|
||||
"""Extract app_id from trace_info, trying metadata first then message_data."""
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if app_id:
|
||||
return str(app_id)
|
||||
message_data = getattr(trace_info, "message_data", None)
|
||||
if message_data is not None:
|
||||
return str(getattr(message_data, "app_id", ""))
|
||||
return ""
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
|
||||
@@ -154,16 +143,13 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
DIFY_APP_ID: self._extract_app_id(trace_info),
|
||||
},
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
@@ -455,8 +441,6 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
|
||||
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
|
||||
|
||||
app_id = self._extract_app_id(trace_info)
|
||||
|
||||
if message_span_id:
|
||||
message_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
@@ -465,16 +449,13 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
outputs=outputs_json,
|
||||
),
|
||||
DIFY_APP_ID: app_id,
|
||||
},
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
outputs=outputs_json,
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
@@ -488,16 +469,13 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
**({DIFY_APP_ID: app_id} if message_span_id is None else {}),
|
||||
},
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
|
||||
@@ -3,9 +3,6 @@ from typing import Final
|
||||
|
||||
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
|
||||
|
||||
# Dify-specific attributes
|
||||
DIFY_APP_ID: Final[str] = "dify.app_id"
|
||||
|
||||
# Public attributes
|
||||
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
|
||||
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
|
||||
|
||||
@@ -129,11 +129,11 @@ class LangfuseSpan(BaseModel):
|
||||
default=None,
|
||||
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
||||
)
|
||||
start_time: datetime | None = Field(
|
||||
start_time: datetime | str | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the span started, defaults to the current time.",
|
||||
)
|
||||
end_time: datetime | None = Field(
|
||||
end_time: datetime | str | None = Field(
|
||||
default=None,
|
||||
description="The time at which the span ended. Automatically set by span.end().",
|
||||
)
|
||||
@@ -146,7 +146,7 @@ class LangfuseSpan(BaseModel):
|
||||
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
|
||||
"via the API.",
|
||||
)
|
||||
level: LevelEnum | None = Field(
|
||||
level: str | None = Field(
|
||||
default=None,
|
||||
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
|
||||
"traces with elevated error levels and for highlighting in the UI.",
|
||||
@@ -222,16 +222,16 @@ class LangfuseGeneration(BaseModel):
|
||||
default=None,
|
||||
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
start_time: datetime | None = Field(
|
||||
start_time: datetime | str | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the generation started, defaults to the current time.",
|
||||
)
|
||||
completion_start_time: datetime | None = Field(
|
||||
completion_start_time: datetime | str | None = Field(
|
||||
default=None,
|
||||
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
|
||||
"down into time until completion started and completion duration.",
|
||||
)
|
||||
end_time: datetime | None = Field(
|
||||
end_time: datetime | str | None = Field(
|
||||
default=None,
|
||||
description="The time at which the generation ended. Automatically set by generation.end().",
|
||||
)
|
||||
|
||||
@@ -18,7 +18,8 @@ except ImportError:
|
||||
from importlib_metadata import version # type: ignore[import-not-found]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Histogram, Meter
|
||||
from opentelemetry.metrics import Meter
|
||||
from opentelemetry.metrics._internal.instrument import Histogram
|
||||
from opentelemetry.sdk.metrics.export import MetricReader
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
|
||||
@@ -3,8 +3,6 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
# from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
@@ -124,7 +122,7 @@ class PluginToolManager(BasePluginClient):
|
||||
},
|
||||
)
|
||||
|
||||
return merge_blob_chunks(response, max_file_size=dify_config.PLUGIN_MAX_FILE_SIZE)
|
||||
return merge_blob_chunks(response)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from core.file.models import File
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from core.workflow.file.models import File
|
||||
|
||||
|
||||
def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -2,7 +2,10 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import file_manager
|
||||
from core.file.models import File
|
||||
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@@ -12,12 +15,9 @@ from core.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.file import file_manager
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ from typing import cast
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
self.history_messages,
|
||||
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
|
||||
)
|
||||
if curr_message_tokens <= max_token_limit:
|
||||
return self.history_messages
|
||||
@@ -65,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
# a message is start with UserPromptMessage
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
prompt_messages,
|
||||
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
|
||||
)
|
||||
# if current message token is overflow, drop all the prompts in current message and break
|
||||
if curr_message_tokens > max_token_limit:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import file_manager
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
@@ -14,15 +16,13 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.file import file_manager
|
||||
from models.model import AppMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.file.models import File
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class ModelMode(StrEnum):
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance # type: ignore
|
||||
from pyobvector import VECTOR, ObVecClient, l2_distance # type: ignore
|
||||
from sqlalchemy import JSON, Column, String
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -20,14 +19,10 @@ from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
|
||||
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
|
||||
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
|
||||
_VALID_TABLE_NAME_RE = re.compile(r"^[a-zA-Z0-9_]+$")
|
||||
|
||||
_DISTANCE_FUNC_MAP = {
|
||||
"l2": l2_distance,
|
||||
"cosine": cosine_distance,
|
||||
"inner_product": inner_product,
|
||||
}
|
||||
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
|
||||
|
||||
|
||||
class OceanBaseVectorConfig(BaseModel):
|
||||
@@ -37,14 +32,6 @@ class OceanBaseVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
enable_hybrid_search: bool = False
|
||||
batch_size: int = 100
|
||||
metric_type: Literal["l2", "cosine", "inner_product"] = "l2"
|
||||
hnsw_m: int = 16
|
||||
hnsw_ef_construction: int = 256
|
||||
hnsw_ef_search: int = -1
|
||||
pool_size: int = 5
|
||||
max_overflow: int = 10
|
||||
hnsw_refresh_threshold: int = 1000
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -62,23 +49,14 @@ class OceanBaseVectorConfig(BaseModel):
|
||||
|
||||
class OceanBaseVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
|
||||
if not _VALID_TABLE_NAME_RE.match(collection_name):
|
||||
raise ValueError(
|
||||
f"Invalid collection name '{collection_name}': "
|
||||
"only alphanumeric characters and underscores are allowed."
|
||||
)
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._hnsw_ef_search = self._config.hnsw_ef_search
|
||||
self._hnsw_ef_search = -1
|
||||
self._client = ObVecClient(
|
||||
uri=f"{self._config.host}:{self._config.port}",
|
||||
user=self._config.user,
|
||||
password=self._config.password,
|
||||
db_name=self._config.database,
|
||||
pool_size=self._config.pool_size,
|
||||
max_overflow=self._config.max_overflow,
|
||||
pool_recycle=3600,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
self._fields: list[str] = [] # List of fields in the collection
|
||||
if self._client.check_table_exists(collection_name):
|
||||
@@ -158,8 +136,8 @@ class OceanBaseVector(BaseVector):
|
||||
field_name="vector",
|
||||
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
|
||||
index_name="vector_index",
|
||||
metric_type=self._config.metric_type,
|
||||
params={"M": self._config.hnsw_m, "efConstruction": self._config.hnsw_ef_construction},
|
||||
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
|
||||
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
|
||||
)
|
||||
|
||||
self._client.create_table_with_index_params(
|
||||
@@ -200,17 +178,6 @@ class OceanBaseVector(BaseVector):
|
||||
else:
|
||||
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
|
||||
|
||||
try:
|
||||
self._client.perform_raw_text_sql(
|
||||
f"CREATE INDEX IF NOT EXISTS idx_metadata_doc_id ON `{self._collection_name}` "
|
||||
f"((CAST(metadata->>'$.document_id' AS CHAR(64))))"
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
logger.warning(
|
||||
"Failed to create metadata functional index on '%s'; metadata queries may be slow without it.",
|
||||
self._collection_name,
|
||||
)
|
||||
|
||||
self._client.refresh_metadata([self._collection_name])
|
||||
self._load_collection_fields()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
@@ -238,49 +205,24 @@ class OceanBaseVector(BaseVector):
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = self._get_uuids(documents)
|
||||
batch_size = self._config.batch_size
|
||||
total = len(documents)
|
||||
|
||||
all_data = [
|
||||
{
|
||||
"id": doc_id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
}
|
||||
for doc_id, doc, emb in zip(ids, documents, embeddings)
|
||||
]
|
||||
|
||||
for start in range(0, total, batch_size):
|
||||
batch = all_data[start : start + batch_size]
|
||||
for id, doc, emb in zip(ids, documents, embeddings):
|
||||
try:
|
||||
self._client.insert(
|
||||
table_name=self._collection_name,
|
||||
data=batch,
|
||||
data={
|
||||
"id": id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to insert batch [%d:%d] into collection '%s'",
|
||||
start,
|
||||
start + len(batch),
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(
|
||||
f"Failed to insert batch [{start}:{start + len(batch)}] into collection '{self._collection_name}'"
|
||||
) from e
|
||||
|
||||
if self._config.hnsw_refresh_threshold > 0 and total >= self._config.hnsw_refresh_threshold:
|
||||
try:
|
||||
self._client.refresh_index(
|
||||
table_name=self._collection_name,
|
||||
index_name="vector_index",
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
logger.warning(
|
||||
"Failed to refresh HNSW index after inserting %d documents into '%s'",
|
||||
total,
|
||||
"Failed to insert document with id '%s' in collection '%s'",
|
||||
id,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to insert document with id '{id}'") from e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
try:
|
||||
@@ -470,7 +412,7 @@ class OceanBaseVector(BaseVector):
|
||||
vec_column_name="vector",
|
||||
vec_data=query_vector,
|
||||
topk=topk,
|
||||
distance_func=self._get_distance_func(),
|
||||
distance_func=l2_distance,
|
||||
output_column_names=["text", "metadata"],
|
||||
with_dist=True,
|
||||
where_clause=_where_clause,
|
||||
@@ -482,31 +424,14 @@ class OceanBaseVector(BaseVector):
|
||||
)
|
||||
raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
|
||||
|
||||
# Convert distance to score and prepare results for processing
|
||||
results = []
|
||||
for _text, metadata_str, distance in cur:
|
||||
score = self._distance_to_score(distance)
|
||||
score = 1 - distance / math.sqrt(2)
|
||||
results.append((_text, metadata_str, score))
|
||||
|
||||
return self._process_search_results(results, score_threshold=score_threshold)
|
||||
|
||||
def _get_distance_func(self):
|
||||
func = _DISTANCE_FUNC_MAP.get(self._config.metric_type)
|
||||
if func is None:
|
||||
raise ValueError(
|
||||
f"Unsupported metric_type '{self._config.metric_type}'. Supported: {', '.join(_DISTANCE_FUNC_MAP)}"
|
||||
)
|
||||
return func
|
||||
|
||||
def _distance_to_score(self, distance: float) -> float:
|
||||
metric = self._config.metric_type
|
||||
if metric == "l2":
|
||||
return 1.0 / (1.0 + distance)
|
||||
elif metric == "cosine":
|
||||
return 1.0 - distance
|
||||
elif metric == "inner_product":
|
||||
return -distance
|
||||
raise ValueError(f"Unsupported metric_type '{metric}'")
|
||||
|
||||
def delete(self):
|
||||
try:
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
@@ -539,13 +464,5 @@ class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
|
||||
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
|
||||
enable_hybrid_search=dify_config.OCEANBASE_ENABLE_HYBRID_SEARCH or False,
|
||||
batch_size=dify_config.OCEANBASE_VECTOR_BATCH_SIZE,
|
||||
metric_type=dify_config.OCEANBASE_VECTOR_METRIC_TYPE,
|
||||
hnsw_m=dify_config.OCEANBASE_HNSW_M,
|
||||
hnsw_ef_construction=dify_config.OCEANBASE_HNSW_EF_CONSTRUCTION,
|
||||
hnsw_ef_search=dify_config.OCEANBASE_HNSW_EF_SEARCH,
|
||||
pool_size=dify_config.OCEANBASE_VECTOR_POOL_SIZE,
|
||||
max_overflow=dify_config.OCEANBASE_VECTOR_MAX_OVERFLOW,
|
||||
hnsw_refresh_threshold=dify_config.OCEANBASE_HNSW_REFRESH_THRESHOLD,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -15,11 +15,11 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str] | None:
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -27,14 +27,14 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -46,7 +46,7 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self) -> None:
|
||||
def delete(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
|
||||
@@ -35,9 +35,7 @@ class CacheEmbedding(Embeddings):
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -54,7 +52,7 @@ class CacheEmbedding(Embeddings):
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model_name, self._model_instance.credentials
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
@@ -89,7 +87,7 @@ class CacheEmbedding(Embeddings):
|
||||
hash = helper.generate_text_hash(texts[i])
|
||||
if hash not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model_name,
|
||||
model_name=self._model_instance.model,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
@@ -116,9 +114,7 @@ class CacheEmbedding(Embeddings):
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -135,7 +131,7 @@ class CacheEmbedding(Embeddings):
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model_name, self._model_instance.credentials
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
@@ -172,7 +168,7 @@ class CacheEmbedding(Embeddings):
|
||||
file_id = multimodel_documents[i]["file_id"]
|
||||
if file_id not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model_name,
|
||||
model_name=self._model_instance.model,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
@@ -194,7 +190,7 @@ class CacheEmbedding(Embeddings):
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
@@ -237,7 +233,7 @@ class CacheEmbedding(Embeddings):
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
file_id = multimodel_document["file_id"]
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any, cast
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -34,7 +35,6 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.file import File
|
||||
from core.file import File
|
||||
|
||||
|
||||
class ChildDocument(BaseModel):
|
||||
|
||||
@@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
model=self.rerank_model_instance.model_name,
|
||||
model=self.rerank_model_instance.model,
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if not is_support_vision:
|
||||
|
||||
@@ -23,12 +23,13 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.token_buffer_memory import TokenBufferMemory
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
@@ -60,7 +61,6 @@ from core.rag.retrieval.template_prompts import (
|
||||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.workflow.file import File, FileTransferMethod, FileType
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.repositories.rag_retrieval_protocol import (
|
||||
KnowledgeRetrievalRequest,
|
||||
|
||||
@@ -2,14 +2,14 @@ import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.file.enums import FileType
|
||||
from core.file.file_manager import download
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.workflow.file.enums import FileType
|
||||
from core.workflow.file.file_manager import download
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
|
||||
@@ -7,13 +7,13 @@ from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from core.file.file_manager import download
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.workflow.file.file_manager import download
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (
|
||||
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
|
||||
|
||||
@@ -12,6 +12,8 @@ from yarl import URL
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import FileType
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
@@ -31,8 +33,6 @@ from core.tools.errors import (
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.workflow.file import FileType
|
||||
from core.workflow.file.models import FileTransferMethod
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
@@ -243,7 +243,7 @@ class ToolFileManager:
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.workflow.file.tool_file_parser import set_tool_file_manager_factory
|
||||
from core.file.tool_file_parser import set_tool_file_manager_factory
|
||||
|
||||
|
||||
def _factory() -> ToolFileManager:
|
||||
|
||||
@@ -8,9 +8,9 @@ from uuid import UUID
|
||||
import numpy as np
|
||||
import pytz
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file import File, FileTransferMethod, FileType
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class ModelInvocationUtils:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
@@ -2,7 +2,7 @@ import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
@@ -14,12 +14,6 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParamet
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class InterfaceDict(TypedDict):
|
||||
path: str
|
||||
method: str
|
||||
operation: dict[str, Any]
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
@@ -41,7 +35,7 @@ class ApiBasedToolSchemaParser:
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces: list[InterfaceDict] = []
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, cast
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@@ -18,7 +19,6 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from factories.file_factory import build_from_mapping
|
||||
from models import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.workflow.file import File
|
||||
from core.file import File
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.file.models import File
|
||||
from core.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class HttpResponseProtocol(Protocol):
|
||||
"""Subset of response behavior needed by workflow file helpers."""
|
||||
|
||||
@property
|
||||
def content(self) -> bytes: ...
|
||||
|
||||
def raise_for_status(self) -> object: ...
|
||||
|
||||
|
||||
class WorkflowFileRuntimeProtocol(Protocol):
|
||||
"""Runtime dependencies required by ``core.workflow.file``.
|
||||
|
||||
Implementations are expected to be provided by integration layers (for example,
|
||||
``core.app.workflow.file_runtime``) so the workflow package avoids importing
|
||||
application infrastructure modules directly.
|
||||
"""
|
||||
|
||||
@property
|
||||
def files_url(self) -> str: ...
|
||||
|
||||
@property
|
||||
def internal_files_url(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def secret_key(self) -> str: ...
|
||||
|
||||
@property
|
||||
def files_access_timeout(self) -> int: ...
|
||||
|
||||
@property
|
||||
def multimodal_send_format(self) -> str: ...
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ...
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...
|
||||
|
||||
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...
|
||||
@@ -1,58 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import NoReturn
|
||||
|
||||
from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
|
||||
|
||||
class WorkflowFileRuntimeNotConfiguredError(RuntimeError):
|
||||
"""Raised when workflow file runtime dependencies were not configured."""
|
||||
|
||||
|
||||
class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
def _raise(self) -> NoReturn:
|
||||
raise WorkflowFileRuntimeNotConfiguredError(
|
||||
"workflow file runtime is not configured, call set_workflow_file_runtime(...) first"
|
||||
)
|
||||
|
||||
@property
|
||||
def files_url(self) -> str:
|
||||
self._raise()
|
||||
|
||||
@property
|
||||
def internal_files_url(self) -> str | None:
|
||||
self._raise()
|
||||
|
||||
@property
|
||||
def secret_key(self) -> str:
|
||||
self._raise()
|
||||
|
||||
@property
|
||||
def files_access_timeout(self) -> int:
|
||||
self._raise()
|
||||
|
||||
@property
|
||||
def multimodal_send_format(self) -> str:
|
||||
self._raise()
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
self._raise()
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
self._raise()
|
||||
|
||||
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
self._raise()
|
||||
|
||||
|
||||
_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime()
|
||||
|
||||
|
||||
def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None:
|
||||
global _runtime
|
||||
_runtime = runtime
|
||||
|
||||
|
||||
def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol:
|
||||
return _runtime
|
||||
@@ -1,9 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
_tool_file_manager_factory: Callable[[], Any] | None = None
|
||||
|
||||
|
||||
def set_tool_file_manager_factory(factory: Callable[[], Any]):
|
||||
global _tool_file_manager_factory
|
||||
_tool_file_manager_factory = factory
|
||||
@@ -7,28 +7,12 @@ Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Protocol, final
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
|
||||
class RedisPipelineProtocol(Protocol):
|
||||
"""Minimal Redis pipeline contract used by the command channel."""
|
||||
|
||||
def lrange(self, name: str, start: int, end: int) -> Any: ...
|
||||
def delete(self, *names: str) -> Any: ...
|
||||
def execute(self) -> list[Any]: ...
|
||||
def rpush(self, name: str, *values: str) -> Any: ...
|
||||
def expire(self, name: str, time: int) -> Any: ...
|
||||
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
|
||||
def get(self, name: str) -> Any: ...
|
||||
|
||||
|
||||
class RedisClientProtocol(Protocol):
|
||||
"""Redis client contract required by the command channel."""
|
||||
|
||||
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
@@ -42,7 +26,7 @@ class RedisChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: RedisClientProtocol,
|
||||
redis_client: "RedisClientWrapper",
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
|
||||
@@ -3,14 +3,13 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Callers must provide a Redis client dependency from outside the workflow package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
@@ -18,6 +17,7 @@ from core.workflow.graph_engine.entities.commands import (
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,12 +31,8 @@ class GraphEngineManager:
|
||||
by sending commands through Redis channels, without user validation.
|
||||
"""
|
||||
|
||||
_redis_client: RedisClientProtocol
|
||||
|
||||
def __init__(self, redis_client: RedisClientProtocol) -> None:
|
||||
self._redis_client = redis_client
|
||||
|
||||
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
@@ -45,31 +41,34 @@ class GraphEngineManager:
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
self._send_command(task_id, abort_command)
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
|
||||
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
self._send_command(task_id, pause_command)
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
self._send_command(task_id, update_command)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
|
||||
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(self._redis_client, channel_key)
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
try:
|
||||
channel.send_command(command)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user