Compare commits

2 Commits

171 changed files with 2289 additions and 2781 deletions

View File

@@ -1,16 +1,15 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
corepack enable
cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc

View File

@@ -8,7 +8,8 @@ on:
- "deploy/enterprise"
- "build/**"
- "release/e-*"
- "hotfix/**"
- "deploy/rag-dev"
- "feat/rag-2"
tags:
- "*"

View File

@@ -4,7 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/dev"
- "deploy/rag-dev"
types:
- completed
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/dev'
github.event.workflow_run.head_branch == 'deploy/rag-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8

.gitignore vendored
View File

@@ -230,8 +230,4 @@ api/.env.backup
# Benchmark
scripts/stress-test/setup/config/
scripts/stress-test/reports/
# mcp
.playwright-mcp/
.serena/
scripts/stress-test/reports/

View File

@@ -61,9 +61,8 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, and import linter..."
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --directory api --dev lint-imports
@echo "🔧 Running ruff format and check with fixes..."
@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@echo "✅ Linting complete"
type-check:

View File

@@ -304,8 +304,6 @@ BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url

View File

@@ -30,7 +30,6 @@ select = [
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"T201", # print-found
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
@@ -92,18 +91,11 @@ ignore = [
"configs/*" = [
"N802", # invalid-function-name
]
"core/model_runtime/callbacks/base_callback.py" = [
"T201",
]
"core/workflow/callbacks/workflow_logging_callback.py" = [
"T201",
]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests
]
[lint.pyflakes]

View File

@@ -1,11 +1,20 @@
import logging
import psycogreen.gevent as pscycogreen_gevent # type: ignore
from grpc.experimental import gevent as grpc_gevent # type: ignore
_logger = logging.getLogger(__name__)
def _log(message: str):
print(message, flush=True)
# grpc gevent
grpc_gevent.init_gevent()
print("gRPC patched with gevent.", flush=True) # noqa: T201
_log("gRPC patched with gevent.")
pscycogreen_gevent.patch_psycopg()
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
_log("psycopg2 patched with gevent.")
from app import app, celery

View File

@@ -10,7 +10,6 @@ from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
@@ -62,30 +61,31 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
account = db.session.query(Account).where(Account.email == email).one_or_none()
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@@ -100,21 +100,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
account = db.session.query(Account).where(Account.email == email).one_or_none()
try:
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
account.email = new_email
click.echo(click.style("Email updated successfully.", fg="green"))
try:
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command(
@@ -138,24 +139,25 @@ def reset_encrypt_key_pair():
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
tenant.encrypt_public_key = generate_key_pair(tenant.id)
click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
)
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
)
)
@click.command("vdb-migrate", help="Migrate vector db.")
@@ -180,15 +182,14 @@ def migrate_annotation_vector_database():
try:
# get apps info
per_page = 50
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
apps = (
db.session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
if not apps:
break
except SQLAlchemyError:
@@ -202,27 +203,26 @@ def migrate_annotation_vector_database():
)
try:
click.echo(f"Creating app annotation index: {app.id}")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@@ -739,18 +739,18 @@ where sites.id is null limit 1000"""
try:
app = db.session.query(App).where(App.id == app_id).first()
if not app:
logger.info("App %s not found", app_id)
print(f"App {app_id} not found")
continue
tenant = app.tenant
if tenant:
accounts = tenant.get_accounts()
if not accounts:
logger.info("Fix failed for app %s", app.id)
print(f"Fix failed for app {app.id}")
continue
account = accounts[0]
logger.info("Fixing missing site for app %s", app.id)
print(f"Fixing missing site for app {app.id}")
app_was_created.send(app, account=account)
except Exception:
failed_app_ids.append(app_id)
@@ -1448,52 +1448,41 @@ def transform_datasource_credentials():
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
access_token = notion_tenant_credential.access_token
# notion info
notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion_datasource",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
access_token = notion_tenant_credential.access_token
# notion info
notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion_datasource",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
continue
db.session.add(datasource_provider)
deal_notion_count += 1
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
@@ -1506,48 +1495,37 @@ def transform_datasource_credentials():
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
continue
db.session.add(datasource_provider)
deal_firecrawl_count += 1
db.session.commit()
# deal jina credentials
deal_jina_count = 0
@@ -1560,45 +1538,36 @@ def transform_datasource_credentials():
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
print(jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
except Exception as e:
click.echo(
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
continue
db.session.add(datasource_provider)
deal_jina_count += 1
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@@ -18,18 +18,3 @@ class EnterpriseFeatureConfig(BaseSettings):
description="Allow customization of the enterprise logo.",
default=False,
)
UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN: str = Field(
description="Token for uploading knowledge pipeline template.",
default="",
)
KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT: str = Field(
description="Knowledge pipeline template copyright.",
default="Copyright 2023 Dify",
)
KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY: str = Field(
description="Knowledge pipeline template privacy policy.",
default="https://dify.ai",
)

View File

@@ -41,13 +41,3 @@ class BaiduVectorDBConfig(BaseSettings):
description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3,
)
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
default="DEFAULT_ANALYZER",
)
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
default="COARSE_MODE",
)

View File

@@ -37,15 +37,3 @@ class OceanBaseVectorConfig(BaseSettings):
"with older versions",
default=False,
)
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
description=(
"Fulltext parser to use for text indexing. "
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
),
default="ik",
)

View File

@@ -5,7 +5,7 @@ import logging
import os
import time
import httpx
import requests
logger = logging.getLogger(__name__)
@@ -30,10 +30,10 @@ class NacosHttpClient:
params = {}
try:
self._inject_auth_info(headers, params)
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except httpx.RequestError as e:
except requests.RequestException as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
@@ -78,7 +78,7 @@ class NacosHttpClient:
params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login"
try:
resp = httpx.request("POST", url, headers=None, params=params)
resp = requests.request("POST", url, headers=None, params=params)
resp.raise_for_status()
response_data = resp.json()
self.token = response_data.get("accessToken")
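A side note on the httpx/requests swap in this Nacos client hunk: the two libraries have different exception hierarchies, so the except clause changes together with the call. A minimal sketch of both styles (URL handling simplified); one behavioral detail worth noting is that httpx.HTTPStatusError raised by raise_for_status() is not a subclass of httpx.RequestError, while requests.HTTPError is a subclass of requests.RequestException:

import httpx
import requests


def fetch_with_httpx(url: str) -> str:
    try:
        response = httpx.request("GET", url)
        response.raise_for_status()
        return response.text
    except httpx.RequestError as e:
        # Catches transport-level failures; an HTTP error status would propagate as HTTPStatusError.
        return f"Request to Nacos failed: {e}"


def fetch_with_requests(url: str) -> str:
    try:
        response = requests.request("GET", url)
        response.raise_for_status()
        return response.text
    except requests.RequestException as e:
        # Catches transport failures and the HTTPError raised by raise_for_status().
        return f"Request to Nacos failed: {e}"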

View File

@@ -1,7 +1,6 @@
from datetime import datetime
import pytz # pip install pytz
import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
@@ -71,7 +70,7 @@ class CompletionConversationApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
query = sa.select(Conversation).where(
query = db.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
@@ -237,7 +236,7 @@ class ChatConversationApi(Resource):
.subquery()
)
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"

View File

@@ -62,9 +62,6 @@ class ChatMessageListApi(Resource):
@account_initialization_required
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")

View File

@@ -50,9 +50,8 @@ class DailyMessageStatistic(Resource):
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -188,9 +187,8 @@ class DailyTerminalsStatistic(Resource):
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -261,9 +259,8 @@ class DailyTokenCostStatistic(Resource):
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -343,9 +340,8 @@ FROM
messages m
ON c.id = m.conversation_id
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -430,9 +426,8 @@ LEFT JOIN
message_feedbacks mf
ON mf.message_id=m.id AND mf.rating='like'
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
m.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -507,9 +502,8 @@ class AverageResponseTimeStatistic(Resource):
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -582,9 +576,8 @@ class TokensPerSecondStatistic(Resource):
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
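The statistics hunks above all share one pattern: a raw SQL string with named placeholders (:app_id, and on one side :invoke_from) plus an arg_dict supplying the bound values. A small, hypothetical sketch of that binding style with SQLAlchemy's text(); the in-memory engine, table, and values below are placeholders for illustration only:

from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///:memory:")  # placeholder engine, not the app's database

sql_query = """SELECT COUNT(*) AS message_count
FROM messages
WHERE app_id = :app_id
  AND invoke_from != :invoke_from"""
arg_dict = {"app_id": "example-app-id", "invoke_from": "debugger"}  # placeholder values

with engine.begin() as conn:
    conn.execute(text("CREATE TABLE messages (app_id TEXT, invoke_from TEXT)"))
    message_count = conn.execute(text(sql_query), arg_dict).scalar()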

View File

@@ -1,6 +1,6 @@
import logging
import httpx
import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except httpx.HTTPStatusError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except httpx.HTTPStatusError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

View File

@@ -1,6 +1,6 @@
import logging
import httpx
import requests
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
@@ -101,10 +101,8 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except httpx.RequestError as e:
error_text = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_text = e.response.text
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400

View File

@@ -782,6 +782,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
@@ -808,7 +809,6 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
@@ -838,6 +838,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
@@ -862,7 +863,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [

View File

@@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -212,13 +211,13 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count":
sub_query = (
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.group_by(DocumentSegment.document_id)
.subquery()
)
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == "created_at":

View File

@@ -14,10 +14,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import (
PipelineBuiltInTemplateEntity,
PipelineTemplateInfoEntity,
)
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
@@ -29,6 +26,12 @@ def _validate_name(name):
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@@ -143,186 +146,6 @@ class PublishCustomizedPipelineTemplateApi(Resource):
return {"result": "success"}
class PipelineTemplateInstallApi(Resource):
"""API endpoint for installing built-in pipeline templates"""
def post(self):
"""
Install a built-in pipeline template
Args:
template_id: The template ID from URL parameter
Returns:
Success response or error with appropriate HTTP status
"""
try:
# Extract and validate Bearer token
auth_token = self._extract_bearer_token()
# Parse and validate request parameters
template_args = self._parse_template_args()
# Process uploaded template file
file_content = self._process_template_file()
# Create template entity
pipeline_built_in_template_entity = PipelineBuiltInTemplateEntity(**template_args)
# Install the template
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.install_built_in_pipeline_template(
pipeline_built_in_template_entity, file_content, auth_token
)
return {"result": "success", "message": "Template installed successfully"}, 200
except ValueError as e:
logger.exception("Validation error in template installation")
return {"error": str(e)}, 400
except Exception as e:
logger.exception("Unexpected error in template installation")
return {"error": "An unexpected error occurred during template installation"}, 500
def _extract_bearer_token(self) -> str:
"""
Extract and validate Bearer token from Authorization header
Returns:
The extracted token string
Raises:
ValueError: If token is missing or invalid
"""
auth_header = request.headers.get("Authorization", "").strip()
if not auth_header:
raise ValueError("Authorization header is required")
if not auth_header.startswith("Bearer "):
raise ValueError("Authorization header must start with 'Bearer '")
token_parts = auth_header.split(" ", 1)
if len(token_parts) != 2:
raise ValueError("Invalid Authorization header format")
auth_token = token_parts[1].strip()
if not auth_token:
raise ValueError("Bearer token cannot be empty")
return auth_token
def _parse_template_args(self) -> dict:
"""
Parse and validate template arguments from form data
Args:
template_id: The template ID from URL
Returns:
Dictionary of validated template arguments
"""
# Use reqparse for consistent parameter parsing
parser = reqparse.RequestParser()
parser.add_argument(
"template_id",
type=str,
location="form",
required=False,
help="Template ID for updating existing template"
)
parser.add_argument(
"language",
type=str,
location="form",
required=True,
default="en-US",
choices=["en-US", "zh-CN", "ja-JP"],
help="Template language code"
)
parser.add_argument(
"name",
type=str,
location="form",
required=True,
default="New Pipeline Template",
help="Template name (1-200 characters)"
)
parser.add_argument(
"description",
type=str,
location="form",
required=False,
default="",
help="Template description (max 1000 characters)"
)
args = parser.parse_args()
# Additional validation
if args.get("name"):
args["name"] = self._validate_name(args["name"])
if args.get("description") and len(args["description"]) > 1000:
raise ValueError("Description must not exceed 1000 characters")
# Filter out None values
return {k: v for k, v in args.items() if v is not None}
def _validate_name(self, name: str) -> str:
"""
Validate template name
Args:
name: Template name to validate
Returns:
Validated and trimmed name
Raises:
ValueError: If name is invalid
"""
name = name.strip()
if not name or len(name) < 1 or len(name) > 200:
raise ValueError("Template name must be between 1 and 200 characters")
return name
def _process_template_file(self) -> str:
"""
Process and validate uploaded template file
Returns:
File content as string
Raises:
ValueError: If file is missing or invalid
"""
if "file" not in request.files:
raise ValueError("Template file is required")
file = request.files["file"]
# Validate file
if not file or not file.filename:
raise ValueError("No file selected")
filename = file.filename.strip()
if not filename:
raise ValueError("File name cannot be empty")
# Check file extension
if not filename.lower().endswith(".pipeline"):
raise ValueError("Template file must be a pipeline file (.pipeline)")
try:
file_content = file.read().decode("utf-8")
except UnicodeDecodeError:
raise ValueError("Template file must be valid UTF-8 text")
return file_content
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
@@ -339,7 +162,3 @@ api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)
api.add_resource(
PipelineTemplateInstallApi,
"/rag/pipeline/built-in/templates/install",
)

View File

@@ -118,14 +118,12 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=str, default="false", location="args")
parser.add_argument("include_secret", type=bool, default=False, location="args")
args = parser.parse_args()
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=args["include_secret"] == "true"
)
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
return {"data": result}, 200

View File

@@ -1,7 +1,7 @@
import json
import logging
import httpx
import requests
from flask_restx import Resource, fields, reqparse
from packaging import version
@@ -57,11 +57,7 @@ class VersionApi(Resource):
return result
try:
response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10),
)
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"]

View File

@@ -79,12 +79,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs

View File

@@ -427,9 +427,6 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -468,7 +465,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def single_loop_generate(
@@ -563,7 +559,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def _generate_worker(

View File

@@ -86,12 +86,29 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close()
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
if self.application_generate_entity.single_iteration_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs

View File

@@ -51,12 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs

View File

@@ -1,4 +1,3 @@
import time
from collections.abc import Mapping
from typing import Any, cast
@@ -120,81 +119,15 @@ class WorkflowBasedAppRunner:
return graph
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
def _get_graph_and_variable_pool_for_single_node_run(
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
user_inputs: dict,
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
@@ -212,22 +145,18 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
# filter nodes only in iteration
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
# filter edges only in iteration
edge_configs = [
edge
for edge in graph_config.get("edges", [])
@@ -261,26 +190,30 @@ class WorkflowBasedAppRunner:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
iteration_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
iteration_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
# Get node class
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=target_node_config
graph_config=workflow.graph_dict, config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
@@ -301,44 +234,120 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
user_inputs: dict,
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break
if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")
# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event

View File

@@ -0,0 +1,388 @@
import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]: # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
            # check if there is an operation id; use $path_$method as the operation id if not
            if "operationId" not in interface["operation"]:
                path = interface["path"]
                if interface["path"].startswith("/"):
                    path = interface["path"][1:]
                # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
                path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
else:
return None
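
A quick illustration of the mapping above, assuming ApiBasedToolSchemaParser and ToolParameter are importable from the surrounding dify modules; the parameter dicts are hypothetical OpenAPI fragments:

ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "integer"}})  # -> ToolParameterType.NUMBER
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "boolean"})              # -> ToolParameterType.BOOLEAN
ApiBasedToolSchemaParser._get_tool_parameter_type({"format": "binary"})             # -> ToolParameterType.FILE
ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"format": "binary"}})  # -> ToolParameterType.FILES
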
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
        """
        parse swagger to openapi
        :param swagger: the swagger dict
        :return: the openapi dict
        """
        warning = warning or {}
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {} # pyright: ignore[reportIndexIssue]
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = { # pyright: ignore[reportIndexIssue]
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # pyright: ignore[reportIndexIssue]
# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition # pyright: ignore[reportIndexIssue, reportArgumentType]
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
        parse openai plugin json to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
        # openapi parse error, fall back to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
            # maybe it's not a plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)
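
A hedged usage sketch of the fallback chain above (OpenAPI first, then Swagger, then an OpenAI plugin manifest), assuming the dify modules in this file are importable; the schema text is a hypothetical minimal OpenAPI document:

minimal_openapi = """
openapi: 3.0.0
info: {title: Ping, version: 1.0.0}
servers:
  - url: https://api.example.com
paths:
  /ping:
    get:
      operationId: ping
      summary: Ping the service
      responses: {}
"""

warning: dict = {}
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(minimal_openapi, warning=warning)
# schema_type == ApiProviderSchemaType.OPENAPI.value and bundles[0].operation_id == "ping".
# A Swagger 2.0 document or an OpenAI plugin manifest would instead fall through to the
# later parsers and come back with the SWAGGER or OPENAI_PLUGIN schema type.
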

View File

@@ -0,0 +1,17 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)
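
For illustration, how the pattern above behaves: leading characters in the listed ASCII and CJK punctuation ranges are stripped, anything else is left alone.

remove_leading_symbols("...Hello, world")   # -> "Hello, world"
remove_leading_symbols("。第一章 概述")      # -> "第一章 概述"
remove_leading_symbols("- item")            # -> "- item" (hyphen is not in the stripped ranges)
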

View File

@@ -0,0 +1,9 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False

View File

@@ -0,0 +1,43 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
        check whether the tool parameter configurations are in sync with the workflow variables
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@@ -0,0 +1,35 @@
import logging
from pathlib import Path
from typing import Any
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
        if True, return default_value when an error occurs and log the error at debug level
        if False, raise the error
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
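
A short usage sketch; the file name is hypothetical:

# Lenient: returns {} (the default_value) if the file is missing or malformed.
manifest = load_yaml_file("provider.yaml", ignore_error=True, default_value={})

# Strict: raises FileNotFoundError for a missing path, or YAMLError for invalid YAML.
manifest = load_yaml_file("provider.yaml", ignore_error=False)
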

View File

@@ -205,10 +205,16 @@ class ProviderConfiguration(BaseModel):
"""
Get custom provider record.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(self._get_provider_names()),
Provider.provider_name.in_(provider_names),
)
return session.execute(stmt).scalar_one_or_none()
@@ -270,7 +276,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.credential_name == credential_name,
)
if exclude_id:
@@ -318,7 +324,7 @@ class ProviderConfiguration(BaseModel):
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
@@ -368,7 +374,7 @@ class ProviderConfiguration(BaseModel):
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
),
)
@@ -381,7 +387,7 @@ class ProviderConfiguration(BaseModel):
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
@@ -417,16 +423,6 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
@@ -505,7 +501,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
)
# Get the credential record to update
@@ -558,7 +554,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source,
)
@@ -595,7 +591,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
)
# Get the credential record to update
@@ -606,7 +602,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
)
@@ -628,7 +624,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@@ -672,7 +668,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.provider_name == self.provider.provider,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -741,7 +737,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -788,7 +784,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
@@ -864,7 +860,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1001,7 +997,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1046,7 +1042,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1056,7 +1052,7 @@ class ProviderConfiguration(BaseModel):
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
)
@@ -1079,7 +1075,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1119,7 +1115,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1161,7 +1157,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@@ -1208,9 +1204,15 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@@ -1382,9 +1384,15 @@ class ProviderConfiguration(BaseModel):
return
def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
TenantPreferredModelProvider.provider_name.in_(provider_names),
)
preferred_model_provider = s.execute(stmt).scalars().first()

View File

@@ -8,7 +8,7 @@ from collections import deque
from collections.abc import Sequence
from datetime import datetime
import httpx
import requests
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
@@ -65,13 +65,13 @@ class TraceClient:
def api_check(self):
try:
response = httpx.head(self.endpoint, timeout=5)
response = requests.head(self.endpoint, timeout=5)
if response.status_code == 405:
return True
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except httpx.RequestError as e:
except requests.RequestException as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@@ -417,7 +417,7 @@ class WeaveDataTrace(BaseTraceInstance):
if not login_status:
raise ValueError("Weave login failed")
else:
logger.info("Weave login successful")
print("Weave login successful")
return True
except Exception as e:
logger.debug("Weave API check failed: %s", str(e))

View File

@@ -513,21 +513,6 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def _get_provider_names(provider_name: str) -> list[str]:
"""
provider_name: `openai` or `langgenius/openai/openai`
return: [`openai`, `langgenius/openai/openai`]
"""
provider_names = [provider_name]
model_provider_id = ModelProviderID(provider_name)
if model_provider_id.is_langgenius():
if "/" in provider_name:
provider_names.append(model_provider_id.provider_name)
else:
provider_names.append(str(model_provider_id))
return provider_names
@staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
"""
@@ -540,10 +525,7 @@ class ProviderManager:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderCredential)
.where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
.order_by(ProviderCredential.created_at.desc())
)
@@ -572,7 +554,7 @@ class ProviderManager:
select(ProviderModelCredential)
.where(
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
ProviderModelCredential.provider_name == provider_name,
ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type,
)
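
The _get_provider_names helper removed above exists because, per its docstring, a provider may be stored either under its short name (openai) or its fully qualified plugin id (langgenius/openai/openai). A hedged standalone sketch of that alias expansion; the real code goes through the ModelProviderID class rather than string handling:

def expand_provider_names(provider_name: str) -> list[str]:
    # Simplified illustration: a langgenius-hosted provider is queried under both naming
    # schemes so rows written by either form still match.
    names = [provider_name]
    if "/" in provider_name:
        org, _, short = provider_name.split("/")
        if org == "langgenius":
            names.append(short)
    else:
        names.append(f"langgenius/{provider_name}/{provider_name}")
    return names

expand_provider_names("openai")                    # -> ["openai", "langgenius/openai/openai"]
expand_provider_names("langgenius/openai/openai")  # -> ["langgenius/openai/openai", "openai"]
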

View File

@@ -1,5 +1,4 @@
import json
import logging
import time
import uuid
from typing import Any
@@ -10,24 +9,11 @@ from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.database import Database
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import (
Field,
FilteringIndex,
HNSWParams,
InvertedIndex,
InvertedIndexAnalyzer,
InvertedIndexFieldAttribute,
InvertedIndexParams,
InvertedIndexParseMode,
Schema,
VectorIndex,
) # type: ignore
from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -36,8 +22,6 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class BaiduConfig(BaseModel):
endpoint: str
@@ -46,11 +30,9 @@ class BaiduConfig(BaseModel):
api_key: str
database: str
index_type: str = "HNSW"
metric_type: str = "IP"
metric_type: str = "L2"
shard: int = 1
replicas: int = 3
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
inverted_index_parser_mode: str = "COARSE_MODE"
@model_validator(mode="before")
@classmethod
@@ -67,9 +49,13 @@ class BaiduConfig(BaseModel):
class BaiduVector(BaseVector):
vector_index: str = "vector_idx"
filtering_index: str = "filtering_idx"
inverted_index: str = "content_inverted_idx"
field_id: str = "id"
field_vector: str = "vector"
field_text: str = "text"
field_metadata: str = "metadata"
field_app_id: str = "app_id"
field_annotation_id: str = "annotation_id"
index_vector: str = "vector_idx"
def __init__(self, collection_name: str, config: BaiduConfig):
super().__init__(collection_name)
@@ -88,6 +74,8 @@ class BaiduVector(BaseVector):
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000
@@ -96,31 +84,29 @@ class BaiduVector(BaseVector):
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
for i in range(start, end, 1):
metadata = documents[i].metadata
row = Row(
id=metadata.get("doc_id", str(uuid.uuid4())),
page_content=documents[i].page_content,
metadata=metadata,
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadatas[i]),
app_id=metadatas[i].get("app_id", ""),
annotation_id=metadatas[i].get("annotation_id", ""),
)
rows.append(row)
table.upsert(rows=rows)
# rebuild vector index after upsert finished
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
start_time = time.time()
table.rebuild_index(self.index_vector)
while True:
time.sleep(1)
index = table.describe_index(self.vector_index)
index = table.describe_index(self.index_vector)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
if res and res.code == 0:
return True
return False
@@ -129,73 +115,53 @@ class BaiduVector(BaseVector):
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
def delete_by_metadata_field(self, key: str, value: str):
# Escape double quotes in value to prevent injection
escaped_value = value.replace('"', '\\"')
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] IN({document_ids})'
anns = AnnSearch(
vector_field=VDBField.VECTOR,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
filter=filter,
)
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
filter=f"document_id IN ({document_ids})",
)
else:
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
retrieve_vector=False,
projections=[self.field_id, self.field_text, self.field_metadata],
retrieve_vector=True,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# document ids filter
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] IN({document_ids})'
request = BM25SearchRequest(
index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
)
res = self._db.table(self._collection_name).bm25_search(
request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
        # the baidu vector database doesn't support bm25 search in the current version
return []
def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
meta = row_data.get(VDBField.METADATA_KEY, {})
# Handle both JSON string and dict formats for backward compatibility
if isinstance(meta, str):
try:
import json
meta = json.loads(meta)
except (json.JSONDecodeError, TypeError):
meta = {}
elif not isinstance(meta, dict):
meta = {}
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)
return docs
def delete(self):
@@ -212,7 +178,7 @@ class BaiduVector(BaseVector):
client = MochowClient(config)
return client
def _init_database(self) -> Database:
def _init_database(self):
exists = False
for db in self._client.list_databases():
if db.database_name == self._client_config.database:
@@ -226,10 +192,10 @@ class BaiduVector(BaseVector):
self._client.create_database(database_name=self._client_config.database)
except ServerError as e:
if e.code == ServerErrCode.DB_ALREADY_EXIST:
return self._client.database(self._client_config.database)
pass
else:
raise
return self._client.database(self._client_config.database)
return
def _table_existed(self) -> bool:
tables = self._db.list_table()
@@ -266,7 +232,7 @@ class BaiduVector(BaseVector):
fields = []
fields.append(
Field(
VDBField.PRIMARY_KEY,
self.field_id,
FieldType.STRING,
primary_key=True,
partition_key=True,
@@ -274,57 +240,24 @@ class BaiduVector(BaseVector):
not_null=True,
)
)
fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
fields.append(Field(self.field_app_id, FieldType.STRING))
fields.append(Field(self.field_annotation_id, FieldType.STRING))
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
# Construct vector index params
indexes = []
indexes.append(
VectorIndex(
index_name=self.vector_index,
index_name="vector_idx",
index_type=index_type,
field=VDBField.VECTOR,
field="vector",
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
)
)
# Filtering index
indexes.append(
FilteringIndex(
index_name=self.filtering_index,
fields=[VDBField.METADATA_KEY],
)
)
# Get analyzer and parse_mode from config
analyzer = getattr(
InvertedIndexAnalyzer,
self._client_config.inverted_index_analyzer,
InvertedIndexAnalyzer.DEFAULT_ANALYZER,
)
parse_mode = getattr(
InvertedIndexParseMode,
self._client_config.inverted_index_parser_mode,
InvertedIndexParseMode.COARSE_MODE,
)
# Inverted index
indexes.append(
InvertedIndex(
index_name=self.inverted_index,
fields=[VDBField.CONTENT_KEY],
params=InvertedIndexParams(
analyzer=analyzer,
parse_mode=parse_mode,
case_sensitive=True,
),
field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
)
)
# Create table
self._db.create_table(
table_name=self._collection_name,
@@ -335,15 +268,11 @@ class BaiduVector(BaseVector):
)
# Wait for table created
timeout = 300 # 5 minutes timeout
start_time = time.time()
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
redis_client.set(table_exist_cache_key, 1, ex=3600)
@@ -367,7 +296,5 @@ class BaiduVectorFactory(AbstractVectorFactory):
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
),
)
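
One detail worth noting from the hunk above: delete_by_metadata_field interpolates a caller-supplied value into a filter expression, and the newer variant escapes embedded double quotes first. A minimal standalone sketch of that escaping; the metadata["..."] filter syntax mirrors what this file sends to the Mochow client and is shown here purely for illustration:

def metadata_filter(key: str, value: str) -> str:
    # Escape double quotes so a value like 'He said "hi"' cannot break out of the quoted literal.
    escaped_value = value.replace('"', '\\"')
    return f'metadata["{key}"] = "{escaped_value}"'

metadata_filter("document_id", "doc-123")   # -> metadata["document_id"] = "doc-123"
metadata_filter("title", 'He said "hi"')    # -> metadata["title"] = "He said \"hi\""
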

View File

@@ -4,7 +4,7 @@ import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient, l2_distance # type: ignore
from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT
@@ -117,39 +117,22 @@ class OceanBaseVector(BaseVector):
columns=cols,
vidxs=vidx_params,
)
logger.debug("DEBUG: Table '%s' created successfully", self._collection_name)
if self._hybrid_search_enabled:
# Get parser from config or use default ik parser
parser_name = dify_config.OCEANBASE_FULLTEXT_PARSER or "ik"
allowed_parsers = ["ngram", "beng", "space", "ngram2", "ik", "japanese_ftparser", "thai_ftparser"]
if parser_name not in allowed_parsers:
raise ValueError(
f"Invalid OceanBase full-text parser: {parser_name}. "
f"Allowed values are: {', '.join(allowed_parsers)}"
try:
if self._hybrid_search_enabled:
self._client.create_fts_idx_with_fts_index_param(
table_name=self._collection_name,
fts_idx_param=FtsIndexParam(
index_name="fulltext_index_for_col_text",
field_names=["text"],
parser_type=FtsParser.IK,
),
)
logger.debug("Hybrid search is enabled, parser_name='%s'", parser_name)
logger.debug(
"About to create fulltext index for collection '%s' using parser '%s'",
self._collection_name,
parser_name,
except Exception as e:
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
+ "to support fulltext index and vector index in the same table",
e,
)
try:
sql_command = f"""ALTER TABLE {self._collection_name}
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER {parser_name}"""
logger.debug("DEBUG: Executing SQL: %s", sql_command)
self._client.perform_raw_text_sql(sql_command)
logger.debug("DEBUG: Fulltext index created successfully for '%s'", self._collection_name)
except Exception as e:
logger.exception("Exception occurred while creating fulltext index")
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be "
"4.3.5.1 or above to support fulltext index and vector index in the same table"
) from e
else:
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
self._client.refresh_metadata([self._collection_name])
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -246,7 +229,7 @@ class OceanBaseVector(BaseVector):
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
print(f"Invalid JSON metadata: {metadata_str}")
metadata = {}
metadata["score"] = score
docs.append(Document(page_content=_text, metadata=metadata))
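
The newer OceanBase path above validates the configured full-text parser against a fixed allowlist before interpolating it into an ALTER TABLE statement, since the parser keyword cannot be bound as a SQL parameter. A hedged sketch of that pattern; the table and parser names are hypothetical:

ALLOWED_PARSERS = {"ngram", "beng", "space", "ngram2", "ik", "japanese_ftparser", "thai_ftparser"}

def fulltext_index_sql(table_name: str, parser_name: str) -> str:
    # Only known parser keywords may reach the raw SQL string.
    if parser_name not in ALLOWED_PARSERS:
        raise ValueError(f"Invalid OceanBase full-text parser: {parser_name}")
    return (
        f"ALTER TABLE {table_name} "
        f"ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER {parser_name}"
    )

fulltext_index_sql("dataset_vectors", "ik")
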

View File

@@ -1,6 +1,5 @@
import array
import json
import logging
import re
import uuid
from typing import Any
@@ -20,8 +19,6 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False
@@ -183,8 +180,8 @@ class OracleVector(BaseVector):
value,
)
conn.commit()
except Exception:
logger.exception("Failed to insert record %s into %s", value[0], self.table_name)
except Exception as e:
print(e)
conn.close()
return pks

View File

@@ -1,5 +1,4 @@
import json
import logging
import uuid
from typing import Any
@@ -24,8 +23,6 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any
@@ -190,8 +187,8 @@ class RelytVector(BaseVector):
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception:
logger.exception("Delete operation failed for collection %s", self._collection_name)
except Exception as e:
print("Delete operation failed:", str(e))
return False
def delete_by_metadata_field(self, key: str, value: str):

View File

@@ -164,8 +164,8 @@ class TiDBVector(BaseVector):
delete_condition = table.c.id.in_(ids)
conn.execute(table.delete().where(delete_condition))
return True
except Exception:
logger.exception("Delete operation failed for collection %s", self._collection_name)
except Exception as e:
print("Delete operation failed:", str(e))
return False
def get_ids_by_metadata_field(self, key: str, value: str):

View File

@@ -417,10 +417,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if db_model is not None:
offload_data = db_model.offload_data
else:
db_model = self._to_db_model(domain_model)
offload_data = db_model.offload_data
offload_data = []
offload_data = db_model.offload_data
if domain_model.inputs is not None:
result = self._truncate_and_upload(
domain_model.inputs,

View File

@@ -1,5 +1,4 @@
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
@@ -9,8 +8,6 @@ from typing import Any, ClassVar, Optional
class SchemaRegistry:
"""Schema registry manages JSON schemas with version support"""
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
_lock: ClassVar[threading.Lock] = threading.Lock()
@@ -86,7 +83,7 @@ class SchemaRegistry:
self.metadata[uri] = metadata
except (OSError, json.JSONDecodeError) as e:
self.logger.warning("Failed to load schema %s/%s: %s", version, schema_name, e)
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
def get_schema(self, uri: str) -> Any | None:
"""Retrieves a schema by URI with version support"""

View File

@@ -396,10 +396,6 @@ class ApiTool(Tool):
# assemble invoke message based on response type
if parsed_response.is_json and isinstance(parsed_response.content, dict):
yield self.create_json_message(parsed_response.content)
# FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
            # We must never break the original flows
yield self.create_text_message(response.text)
else:
# Convert to string if needed and create text message
text_response = (

View File

@@ -41,8 +41,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
node_executions: list[NodeExecutionState] = Field(default_factory=list)
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@@ -104,8 +103,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
def start(self) -> None:
"""Mark the graph execution as started."""
@@ -174,7 +172,6 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
)
@@ -198,7 +195,6 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
@@ -209,7 +205,3 @@ class GraphExecution:
)
for item in state.node_executions
}
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1

View File

@@ -3,12 +3,11 @@ Event handler implementations for different event types.
"""
import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -123,15 +122,13 @@ class EventHandler:
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
@@ -164,7 +161,7 @@ class EventHandler:
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
@@ -194,7 +191,7 @@ class EventHandler:
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._update_response_outputs(event)
# Collect the event
self._event_collector.collect(event)
@@ -210,7 +207,6 @@ class EventHandler:
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
result = self._error_handler.handle_node_failure(event)
@@ -231,40 +227,10 @@ class EventHandler:
Args:
event: The node exception event
"""
# Node continues via fail-branch/default-value, treat as completion
# Node continues via fail-branch, so it's technically "succeeded"
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
@@ -276,31 +242,21 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
# Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
"""
Store node outputs in the variable pool.
Args:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
for variable_name, variable_value in event.node_run_result.outputs.items():
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in outputs.items():
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:

View File

@@ -5,7 +5,6 @@ Unified event manager for collecting and emitting events.
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import final
from core.workflow.graph_events import GraphEngineEvent
@@ -52,23 +51,43 @@ class ReadWriteLock:
"""Release a write lock."""
self._read_ready.release()
@contextmanager
def read_lock(self):
def read_lock(self) -> "ReadLockContext":
"""Return a context manager for read locking."""
self.acquire_read()
try:
yield
finally:
self.release_read()
return ReadLockContext(self)
@contextmanager
def write_lock(self):
def write_lock(self) -> "WriteLockContext":
"""Return a context manager for write locking."""
self.acquire_write()
try:
yield
finally:
self.release_write()
return WriteLockContext(self)
@final
class ReadLockContext:
"""Context manager for read locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "ReadLockContext":
self._lock.acquire_read()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_read()
@final
class WriteLockContext:
"""Context manager for write locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "WriteLockContext":
self._lock.acquire_write()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_write()
@final
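
A short usage sketch for the lock and the explicit context-manager objects introduced above; the collected event names are hypothetical:

lock = ReadWriteLock()
collected: list[str] = []

# Multiple readers can hold the lock at the same time.
with lock.read_lock():
    snapshot = list(collected)

# A writer blocks until every reader has released the lock.
with lock.write_lock():
    collected.append("node_run_started")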

View File

@@ -23,7 +23,6 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@@ -261,23 +260,12 @@ class GraphEngine:
if self._graph_execution.error:
raise self._graph_execution.error
else:
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
else:
yield GraphRunSucceededEvent(
outputs=outputs,
)
yield GraphRunSucceededEvent(
outputs=self._graph_runtime_state.outputs,
)
except Exception as e:
yield GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
yield GraphRunFailedEvent(error=str(e))
raise
finally:

View File

@@ -15,7 +15,6 @@ from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
@@ -128,13 +127,6 @@ class DebugLoggingLayer(GraphEngineLayer):
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:
@@ -146,12 +138,6 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events
        # Check Retry before Started because NodeRunRetryEvent subclasses NodeRunStartedEvent
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
@@ -181,6 +167,11 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else ""

View File

@@ -147,4 +147,4 @@ class ExecutionLimitsLayer(GraphEngineLayer):
self.logger.debug("Abort command sent to engine")
except Exception:
self.logger.exception("Failed to send abort command")
self.logger.exception("Failed to send abort command: %s")

View File

@@ -1,11 +1,9 @@
import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.variables import IntegerVariable, NoneSegment
@@ -21,7 +19,6 @@ from core.workflow.enums import (
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
@@ -37,7 +34,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -242,8 +238,6 @@ class IterationNode(Node):
self._execute_single_iteration_parallel,
index=index,
item=item,
flask_app=current_app._get_current_object(), # type: ignore
context_vars=contextvars.copy_context(),
)
future_to_index[future] = index
@@ -286,29 +280,26 @@ class IterationNode(Node):
self,
index: int,
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
graph_engine = self._create_graph_engine(index, item)
graph_engine = self._create_graph_engine(index, item)
# Collect events instead of yielding them directly
for event in self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs_temp,
graph_engine=graph_engine,
):
events.append(event)
# Collect events instead of yielding them directly
for event in self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs_temp,
graph_engine=graph_engine,
):
events.append(event)
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
def _handle_iteration_success(
self,
@@ -381,16 +372,43 @@ class IterationNode(Node):
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
iteration_node_ids = set()
        # Find all nodes that belong to this iteration
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
@@ -426,7 +444,9 @@ class IterationNode(Node):
variable_mapping.update(sub_node_variable_mapping)
# remove variable out from iteration
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
return variable_mapping
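(Illustrative only — node IDs and selectors below are invented.) The filtering step above drops any mapping whose selector's first element is a node inside the iteration, since that value is produced during the iteration run rather than supplied from outside:

iteration_node_ids = {"llm-in-iteration"}  # stands in for iteration_graph.node_ids
variable_mapping = {
    "iteration-1.input_selector": ["start", "items"],            # kept: produced outside the iteration
    "llm-in-iteration.#context#": ["llm-in-iteration", "text"],  # dropped: produced inside the iteration
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
print(variable_mapping)  # {'iteration-1.input_selector': ['start', 'items']}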
@@ -465,7 +485,7 @@ class IterationNode(Node):
if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
elif isinstance(event, GraphRunSucceededEvent):
result = variable_pool.get(self._node_data.output_selector)
if result is None:
outputs.append(None)

View File

@@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
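Because only the spelling of the full-text option changes here, a minimal sketch (class name assumed, pydantic v2) of how the `Literal` now validates payloads:

from typing import Literal

from pydantic import BaseModel, ValidationError

class RetrievalSettingSketch(BaseModel):
    # Hypothetical stand-in for the model above; only the renamed literal matters.
    search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
    top_k: int

RetrievalSettingSketch(search_method="fulltext_search", top_k=3)  # accepted
try:
    RetrievalSettingSketch(search_method="full_text_search", top_k=3)  # old spelling is now rejected
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('search_method',)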

View File

@@ -1,4 +1,3 @@
import contextlib
import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
@@ -128,13 +127,11 @@ class LoopNode(Node):
try:
reach_break_condition = False
if break_conditions:
with contextlib.suppress(ValueError):
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if reach_break_condition:
loop_count = 0
cost_tokens = 0
@@ -298,11 +295,42 @@ class LoopNode(Node):
variable_mapping = {}
# Extract loop node IDs statically from graph_config
# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Get node configs from graph_config
# Create minimal GraphRuntimeState for static analysis
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not loop_graph:
raise ValueError("loop graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
@@ -343,35 +371,12 @@ class LoopNode(Node):
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
This method statically analyzes the graph configuration to find all nodes
that are part of the specified loop, without creating actual node instances.
:param graph_config: the complete graph configuration
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
loop_node_ids.add(node_id)
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""

View File

@@ -33,7 +33,6 @@ file_fields = {
"created_by": fields.String,
"created_at": TimestampField,
"preview_url": fields.String,
"source_url": fields.String,
}

View File

@@ -1,32 +1,10 @@
import psycogreen.gevent as pscycogreen_gevent # type: ignore
from gevent import events as gevent_events
from grpc.experimental import gevent as grpc_gevent # type: ignore
# NOTE(QuantumGhost): here we cannot use post_fork to patch gRPC, as
# grpc_gevent.init_gevent must be called after patching stdlib.
# Gunicorn calls `post_fork` before applying the monkey patch, so using
# `post_fork` to set up gRPC gevent support would cause deadlocks and
# some other weird issues.
#
# ref:
# - https://github.com/grpc/grpc/blob/62533ea13879d6ee95c6fda11ec0826ca822c9dd/src/python/grpcio/grpc/experimental/gevent.py
# - https://github.com/gevent/gevent/issues/2060#issuecomment-3016768668
# - https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py#L607-L613
def post_patch(event):
# this function is only called for the gevent worker.
# from gevent docs (https://www.gevent.org/api/gevent.monkey.html):
# You can also subscribe to the events to provide additional patching beyond what gevent distributes, either for
# additional standard library modules, or for third-party packages. The suggested time to do this patching is in
# the subscriber for gevent.events.GeventDidPatchBuiltinModulesEvent.
if not isinstance(event, gevent_events.GeventDidPatchBuiltinModulesEvent):
return
def post_fork(server, worker):
# grpc gevent
grpc_gevent.init_gevent()
print("gRPC patched with gevent.", flush=True) # noqa: T201
server.log.info("gRPC patched with gevent.")
pscycogreen_gevent.patch_psycopg()
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
gevent_events.subscribers.append(post_patch)
server.log.info("psycopg2 patched with gevent.")

View File

@@ -1,7 +1,7 @@
import urllib.parse
from dataclasses import dataclass
import httpx
import requests
@dataclass
@@ -58,7 +58,7 @@ class GitHubOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@@ -70,11 +70,11 @@ class GitHubOAuth(OAuth):
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response = requests.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
@@ -112,7 +112,7 @@ class GoogleOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@@ -124,7 +124,7 @@ class GoogleOAuth(OAuth):
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response = requests.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
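A hedged sketch of the token-exchange pattern used above, with an explicit timeout and status check added (the helper name and defaults are assumptions, not part of this change):

import requests

def exchange_code_for_token(token_url: str, data: dict, timeout: float = 10.0) -> str:
    # Same shape as the requests.post calls above, plus a timeout and
    # raise_for_status, neither of which requests applies by default.
    response = requests.post(token_url, data=data, headers={"Accept": "application/json"}, timeout=timeout)
    response.raise_for_status()
    access_token = response.json().get("access_token")
    if not access_token:
        raise ValueError(f"Error in OAuth token exchange: {response.json()}")
    return access_token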

View File

@@ -1,7 +1,7 @@
import urllib.parse
from typing import Any
import httpx
import requests
from flask_login import current_user
from sqlalchemy import select
@@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
"Notion-Version": "2022-06-28",
}
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))
@@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
message = response_json.get("message", "unknown error")
@@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json()
if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"]
@@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))

View File

@@ -47,7 +47,7 @@ def upgrade():
sa.Column('plugin_id', sa.String(length=255), nullable=False),
sa.Column('auth_type', sa.String(length=255), nullable=False),
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('avatar_url', sa.Text(), nullable=True),
sa.Column('avatar_url', sa.String(length=255), nullable=True),
sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

View File

@@ -1,37 +0,0 @@
"""remove-builtin-template-user
Revision ID: bf0bcbf45396
Revises: 68519ad5cd18
Create Date: 2025-09-25 16:50:32.245503
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'bf0bcbf45396'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.drop_column('updated_by')
batch_op.drop_column('created_by')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###

View File

@@ -910,7 +910,7 @@ class AppDatasetJoin(Base):
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def app(self):
@@ -931,7 +931,7 @@ class DatasetQuery(Base):
source_app_id = mapped_column(StringUUID, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(Base):
@@ -1239,6 +1239,15 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
language = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = db.Column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True)
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]

View File

@@ -1044,7 +1044,7 @@ class Message(Base):
sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
elif "file-preview" in url:
# get upload file id
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview\?timestamp="
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
result = re.search(upload_file_id_pattern, url)
if not result:
continue
@@ -1055,7 +1055,7 @@ class Message(Base):
sign_url = file_helpers.get_signed_file_url(upload_file_id)
elif "image-preview" in url:
# image-preview is deprecated, use file-preview instead
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview\?timestamp="
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
result = re.search(upload_file_id_pattern, url)
if not result:
continue
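The regex change above is subtle, so a small standalone check (illustrative URLs only): making the trailing letter optional with `w?` lets the pattern also match the misspelled `file-previe?` path, which the stricter pattern rejects:

import re

strict = r"\/files\/([\w-]+)\/file-preview\?timestamp="
lenient = r"\/files\/([\w-]+)\/file-preview?\?timestamp="  # 'w?' makes the final letter optional

ok = "/files/abc-123/file-preview?timestamp=111"
typo = "/files/abc-123/file-previe?timestamp=111"

print(bool(re.search(strict, ok)), bool(re.search(strict, typo)))    # True False
print(bool(re.search(lenient, ok)), bool(re.search(lenient, typo)))  # True True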
@@ -1731,7 +1731,7 @@ class MessageChain(Base):
type: Mapped[str] = mapped_column(String(255), nullable=False)
input = mapped_column(sa.Text, nullable=True)
output = mapped_column(sa.Text, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
class MessageAgentThought(Base):
@@ -1769,7 +1769,7 @@ class MessageAgentThought(Base):
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def files(self) -> list[Any]:
@@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base):
index_node_hash = mapped_column(sa.Text, nullable=True)
retriever_from = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(Base):

View File

@@ -35,7 +35,7 @@ class DatasourceProvider(Base):
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default")
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")

View File

@@ -1,11 +1,11 @@
[project]
name = "dify-api"
version = "1.9.0"
version = "2.0.0-beta2"
requires-python = ">=3.11,<3.13"
dependencies = [
"arize-phoenix-otel~=0.9.2",
"authlib==1.6.4",
"authlib==1.3.1",
"azure-identity==1.16.1",
"beautifulsoup4==4.12.2",
"boto3==1.35.99",
@@ -20,7 +20,7 @@ dependencies = [
"flask-migrate~=4.0.7",
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent~=24.11.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-python-client==2.90.0",
@@ -169,7 +169,7 @@ dev = [
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy~=1.17.1",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"locust>=2.40.4",
"sseclient-py>=1.8.0",
]
@@ -211,7 +211,7 @@ vdb = [
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvector==0.2.5",
"pymilvus~=2.5.0",
"pymochow==2.2.9",
"pymochow==1.3.1",
"pyobvector~=0.2.15",
"qdrant-client==1.9.0",
"tablestore==6.2.0",

View File

@@ -2,7 +2,6 @@ import json
import logging
from typing import TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@@ -66,7 +65,7 @@ class AppService:
return None
app_models = db.paginate(
sa.select(App).where(*filters).order_by(App.created_at.desc()),
db.select(App).where(*filters).order_by(App.created_at.desc()),
page=args["page"],
per_page=args["limit"],
error_out=False,

View File

@@ -1,6 +1,6 @@
import json
import httpx
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
@@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@@ -1,6 +1,6 @@
import json
import httpx
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
@@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@@ -1,6 +1,6 @@
import json
import httpx
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
@@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@@ -1,7 +1,7 @@
import json
from urllib.parse import urljoin
import httpx
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
@@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers):
return httpx.get(url, headers=headers)
return requests.get(url, headers=headers)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@@ -115,12 +115,12 @@ class DatasetService:
# Check if permitted_dataset_ids is not empty to avoid WHERE false condition
if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
query = query.where(
sa.or_(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
sa.and_(
db.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
),
sa.and_(
db.and_(
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
Dataset.id.in_(permitted_dataset_ids),
),
@@ -128,9 +128,9 @@ class DatasetService:
)
else:
query = query.where(
sa.or_(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
sa.and_(
db.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
),
)
@@ -1879,7 +1879,7 @@ class DocumentService:
# for notion_info in notion_info_list:
# workspace_id = notion_info.workspace_id
# data_source_binding = DataSourceOauthBinding.query.filter(
# sa.and_(
# db.and_(
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
# DataSourceOauthBinding.provider == "notion",
# DataSourceOauthBinding.disabled == False,

View File

@@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
@@ -128,10 +128,3 @@ class KnowledgeConfiguration(BaseModel):
if v is None:
return ""
return v
class PipelineBuiltInTemplateEntity(BaseModel):
template_id: str | None = None
name: str
description: str
language: str

View File

@@ -1,6 +1,6 @@
import os
import httpx
import requests
class OperationService:
@@ -12,7 +12,7 @@ class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
response = requests.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@@ -46,11 +46,7 @@ limit 1000"""
record_id = str(i.id)
provider_name = str(i.provider_name)
retrieval_model = i.retrieval_model
logger.debug(
"Processing dataset %s with retrieval model of type %s",
record_id,
type(retrieval_model),
)
print(type(retrieval_model))
if record_id in failed_ids:
continue

View File

@@ -471,7 +471,7 @@ class PluginMigration:
total_failed_tenant = 0
while True:
# paginate
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
if tenants.items is None or len(tenants.items) == 0:
break

View File

@@ -74,4 +74,5 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"chunk_structure": pipeline_template.chunk_structure,
"export_data": pipeline_template.yaml_content,
"graph": graph_data,
"created_by": pipeline_template.created_user_name,
}

View File

@@ -8,7 +8,6 @@ from datetime import UTC, datetime
from typing import Any, Union, cast
from uuid import uuid4
import yaml
from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, sessionmaker
@@ -61,7 +60,6 @@ from models.dataset import ( # type: ignore
Document,
DocumentPipelineExecutionLog,
Pipeline,
PipelineBuiltInTemplate,
PipelineCustomizedTemplate,
PipelineRecommendedPlugin,
)
@@ -78,7 +76,6 @@ from repositories.factory import DifyAPIRepositoryFactory
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
PipelineBuiltInTemplateEntity,
PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
@@ -1330,14 +1327,14 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_execution_log = (
document_pipeline_excution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
)
if not document_pipeline_execution_log:
if not document_pipeline_excution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@@ -1349,10 +1346,10 @@ class RagPipelineService:
workflow=workflow,
user=user,
args={
"inputs": document_pipeline_execution_log.input_data,
"start_node_id": document_pipeline_execution_log.datasource_node_id,
"datasource_type": document_pipeline_execution_log.datasource_type,
"datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
"inputs": document_pipeline_excution_log.input_data,
"start_node_id": document_pipeline_excution_log.datasource_node_id,
"datasource_type": document_pipeline_excution_log.datasource_type,
"datasource_info_list": [json.loads(document_pipeline_excution_log.datasource_info)],
"original_document_id": document.id,
},
invoke_from=InvokeFrom.PUBLISHED,
@@ -1384,8 +1381,8 @@ class RagPipelineService:
datasource_nodes = workflow.graph_dict.get("nodes", [])
datasource_plugins = []
for datasource_node in datasource_nodes:
if datasource_node.get("data", {}).get("type") == "datasource":
datasource_node_data = datasource_node["data"]
if datasource_node.get("type") == "datasource":
datasource_node_data = datasource_node.get("data", {})
if not datasource_node_data:
continue
@@ -1457,140 +1454,3 @@ class RagPipelineService:
if not pipeline:
raise ValueError("Pipeline not found")
return pipeline
def install_built_in_pipeline_template(
self, args: PipelineBuiltInTemplateEntity, file_content: str, auth_token: str
) -> None:
"""
Install built-in pipeline template
Args:
args: Pipeline built-in template entity with template metadata
file_content: YAML content of the pipeline template
auth_token: Authentication token for authorization
Raises:
ValueError: If validation fails or template processing errors occur
"""
# Validate authentication
self._validate_auth_token(auth_token)
# Parse and validate template content
pipeline_template_dsl = self._parse_template_content(file_content)
# Extract template metadata
icon = self._extract_icon_metadata(pipeline_template_dsl)
chunk_structure = self._extract_chunk_structure(pipeline_template_dsl)
# Prepare template data
template_data = {
"name": args.name,
"description": args.description,
"chunk_structure": chunk_structure,
"icon": icon,
"language": args.language,
"yaml_content": file_content,
}
# Use transaction for database operations
try:
if args.template_id:
self._update_existing_template(args.template_id, template_data)
else:
self._create_new_template(template_data)
db.session.commit()
except Exception as e:
db.session.rollback()
raise ValueError(f"Failed to install pipeline template: {str(e)}")
def _validate_auth_token(self, auth_token: str) -> None:
"""Validate the authentication token"""
config_auth_token = dify_config.UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN
if not config_auth_token:
raise ValueError("Auth token configuration is required")
if config_auth_token != auth_token:
raise ValueError("Auth token is incorrect")
def _parse_template_content(self, file_content: str) -> dict:
"""Parse and validate YAML template content"""
try:
pipeline_template_dsl = yaml.safe_load(file_content)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML content: {str(e)}")
if not pipeline_template_dsl:
raise ValueError("Pipeline template DSL is required")
return pipeline_template_dsl
def _extract_icon_metadata(self, pipeline_template_dsl: dict) -> dict:
"""Extract icon metadata from template DSL"""
rag_pipeline_info = pipeline_template_dsl.get("rag_pipeline", {})
return {
"icon": rag_pipeline_info.get("icon", "📙"),
"icon_type": rag_pipeline_info.get("icon_type", "emoji"),
"icon_background": rag_pipeline_info.get("icon_background", "#FFEAD5"),
"icon_url": rag_pipeline_info.get("icon_url"),
}
def _extract_chunk_structure(self, pipeline_template_dsl: dict) -> str:
"""Extract chunk structure from template DSL"""
nodes = pipeline_template_dsl.get("workflow", {}).get("graph", {}).get("nodes", [])
# Use generator expression for efficiency
chunk_structure = next(
(
node.get("data", {}).get("chunk_structure")
for node in nodes
if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_INDEX.value
),
None
)
if not chunk_structure:
raise ValueError("Chunk structure is required in template")
return chunk_structure
def _update_existing_template(self, template_id: str, template_data: dict) -> None:
"""Update an existing pipeline template"""
pipeline_built_in_template = (
db.session.query(PipelineBuiltInTemplate)
.filter(PipelineBuiltInTemplate.id == template_id)
.first()
)
if not pipeline_built_in_template:
raise ValueError(f"Pipeline built-in template not found: {template_id}")
# Update template fields
for key, value in template_data.items():
setattr(pipeline_built_in_template, key, value)
db.session.add(pipeline_built_in_template)
def _create_new_template(self, template_data: dict) -> None:
"""Create a new pipeline template"""
# Get the next available position
position = self._get_next_position(template_data["language"])
# Add additional fields for new template
template_data.update({
"position": position,
"install_count": 0,
"copyright": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT,
"privacy_policy": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY,
})
new_template = PipelineBuiltInTemplate(**template_data)
db.session.add(new_template)
def _get_next_position(self, language: str) -> int:
"""Get the next available position for a template in the specified language"""
max_position = (
db.session.query(func.max(PipelineBuiltInTemplate.position))
.filter(PipelineBuiltInTemplate.language == language)
.scalar()
)
return (max_position or 0) + 1

View File

@@ -685,24 +685,12 @@ class RagPipelineDslService:
workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node_data.get("dataset_ids", [])
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [

View File

@@ -1,5 +1,4 @@
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from uuid import uuid4
@@ -18,8 +17,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import Knowledge
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
@@ -38,11 +35,11 @@ class RagPipelineTransformService:
indexing_technique = dataset.indexing_technique
if not datasource_type and not indexing_technique:
return self._transform_to_empty_pipeline(dataset)
return self._transfrom_to_empty_pipeline(dataset)
doc_form = dataset.doc_form
if not doc_form:
return self._transform_to_empty_pipeline(dataset)
return self._transfrom_to_empty_pipeline(dataset)
retrieval_model = dataset.retrieval_model
pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
# deal dependencies
@@ -260,10 +257,10 @@ class RagPipelineTransformService:
if plugin_unique_identifier:
need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
if need_install_plugin_unique_identifiers:
logger.debug("Installing missing pipeline plugins %s", need_install_plugin_unique_identifiers)
print(need_install_plugin_unique_identifiers)
PluginService.install_from_marketplace_pkg(tenant_id, need_install_plugin_unique_identifiers)
def _transform_to_empty_pipeline(self, dataset: Dataset):
def _transfrom_to_empty_pipeline(self, dataset: Dataset):
pipeline = Pipeline(
tenant_id=dataset.tenant_id,
name=dataset.name,

View File

@@ -1,6 +1,5 @@
import uuid
import sqlalchemy as sa
from flask_login import current_user
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
@@ -19,7 +18,7 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results

View File

@@ -262,14 +262,6 @@ class VariableTruncator:
target_length = self._array_element_limit
for i, item in enumerate(value):
# Dirty fix:
# The output of the `Start` node may contain a list of `File` elements,
# causing `AssertionError` while invoking `_truncate_json_primitives`.
#
# This check ensures that `list[File]` are handled separately
if isinstance(item, File):
truncated_value.append(item)
continue
if i >= target_length:
return _PartResult(truncated_value, used_size, True)
if i > 0:

View File

@@ -3,7 +3,7 @@ import json
from dataclasses import dataclass
from typing import Any
import httpx
import requests
from flask_login import current_user
from core.helper import encrypter
@@ -216,7 +216,7 @@ class WebsiteService:
@classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages:
response = httpx.get(
response = requests.get(
f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
@@ -224,7 +224,7 @@ class WebsiteService:
raise ValueError("Failed to crawl:")
return {"status": "active", "data": response.json().get("data")}
else:
response = httpx.post(
response = requests.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": request.url,
@@ -287,7 +287,7 @@ class WebsiteService:
@classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = httpx.post(
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
@@ -303,7 +303,7 @@ class WebsiteService:
}
if crawl_status_data["status"] == "completed":
response = httpx.post(
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
@@ -362,7 +362,7 @@ class WebsiteService:
@classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id:
response = httpx.get(
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
@@ -371,7 +371,7 @@ class WebsiteService:
return dict(response.json().get("data", {}))
else:
# Get crawl status first
status_response = httpx.post(
status_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
@@ -381,7 +381,7 @@ class WebsiteService:
raise ValueError("Crawl job is not completed")
# Get processed data
data_response = httpx.post(
data_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},

View File

@@ -450,8 +450,7 @@ class WorkflowService:
)
if not default_provider:
# plugin does not require credentials, skip
return
raise ValueError("No default credential found")
# Check credential policy compliance using the default credential ID
from core.helper.credential_utils import check_credential_policy_compliance

View File

@@ -2,7 +2,6 @@ import logging
import time
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import select
@@ -52,7 +51,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
sa.and_(
db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,

View File

@@ -1,14 +1,12 @@
"""Integration tests for ChatMessageApi permission verification."""
import uuid
from types import SimpleNamespace
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.app import completion as completion_api
from controllers.console.app import message as message_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
@@ -101,106 +99,3 @@ class TestChatMessageApiPermissions:
)
assert response.status_code == status
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_get_requires_edit_permission(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Ensure GET chat-messages endpoint enforces edit permissions."""
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
conversation_id = uuid.uuid4()
created_at = naive_utc_now()
mock_conversation = SimpleNamespace(id=str(conversation_id), app_id=str(mock_app_model.id))
mock_message = SimpleNamespace(
id=str(uuid.uuid4()),
conversation_id=str(conversation_id),
inputs=[],
query="hello",
message=[{"text": "hello"}],
message_tokens=0,
re_sign_file_url_answer="",
answer_tokens=0,
provider_response_latency=0.0,
from_source="console",
from_end_user_id=None,
from_account_id=mock_account.id,
feedbacks=[],
workflow_run_id=None,
annotation=None,
annotation_hit_history=None,
created_at=created_at,
agent_thoughts=[],
message_files=[],
message_metadata_dict={},
status="success",
error="",
parent_message_id=None,
)
class MockQuery:
def __init__(self, model):
self.model = model
def where(self, *args, **kwargs):
return self
def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None
def order_by(self, *args, **kwargs):
return self
def limit(self, *_):
return self
def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []
mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
class DummyPagination:
def __init__(self, data, limit, has_more):
self.data = data
self.limit = limit
self.has_more = has_more
monkeypatch.setattr(message_api, "InfiniteScrollPagination", DummyPagination)
mock_account.role = role
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/chat-messages",
headers=auth_header,
query_string={"conversation_id": str(conversation_id)},
)
assert response.status_code == status

View File

@@ -1,8 +1,8 @@
import os
from typing import Literal
import httpx
import pytest
import requests
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
from core.tools.entities.common_entities import I18nObject
@@ -27,11 +27,13 @@ class MockedHttp:
@classmethod
def requests_request(
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> httpx.Response:
) -> requests.Response:
"""
Mocked httpx.request
Mocked requests.request
"""
request = httpx.Request(method, url)
request = requests.PreparedRequest()
request.method = method
request.url = url
if url.endswith("/tools"):
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
code=0, message="success", data=cls.list_tools()
@@ -39,7 +41,8 @@ class MockedHttp:
else:
raise ValueError("")
response = httpx.Response(status_code=200)
response = requests.Response()
response.status_code = 200
response.request = request
response._content = content.encode("utf-8")
return response
@@ -51,7 +54,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK_SWITCH:
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
def unpatch():
monkeypatch.undo()

View File

@@ -100,8 +100,8 @@ class MockBaiduVectorDBClass:
"row": {
"id": primary_key.get("id"),
"vector": [0.23432432, 0.8923744, 0.89238432],
"page_content": "text",
"metadata": {"doc_id": "doc_id_001"},
"text": "text",
"metadata": '{"doc_id": "doc_id_001"}',
},
"code": 0,
"msg": "Success",
@@ -127,8 +127,8 @@ class MockBaiduVectorDBClass:
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"page_content": "text",
"metadata": {"doc_id": "doc_id_001"},
"text": "text",
"metadata": '{"doc_id": "doc_id_001"}',
},
"distance": 0.1,
"score": 0.5,

View File

@@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
import os
import time
import httpx
import requests
from clickzetta import connect
@@ -66,7 +66,7 @@ def test_dify_api():
max_retries = 30
for i in range(max_retries):
try:
response = httpx.get(f"{base_url}/console/api/health")
response = requests.get(f"{base_url}/console/api/health")
if response.status_code == 200:
print("✓ Dify API is ready")
break

View File

@@ -173,7 +173,7 @@ class DifyTestContainers:
# Start Dify Plugin Daemon container for plugin management
# Dify Plugin Daemon provides plugin lifecycle management and execution
logger.info("Initializing Dify Plugin Daemon container...")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.2.0-local")
self.dify_plugin_daemon.with_exposed_ports(5002)
self.dify_plugin_daemon.env = {
"DB_HOST": db_host,

View File

@@ -13,7 +13,6 @@ import pytest
from faker import Faker
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
@@ -203,6 +202,7 @@ class TestBatchCleanDocumentTask:
UploadFile: Created upload file instance
"""
fake = Faker()
from datetime import datetime
from models.enums import CreatorUserRole
@@ -216,7 +216,7 @@ class TestBatchCleanDocumentTask:
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=naive_utc_now(),
created_at=datetime.utcnow(),
used=False,
)

View File

@@ -201,9 +201,9 @@ class TestOAuthCallback:
mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception
import httpx
import requests
request_exception = httpx.RequestError("OAuth error")
request_exception = requests.exceptions.RequestException("OAuth error")
request_exception.response = MagicMock()
request_exception.response.text = str(exception)

View File

@@ -1,5 +1,6 @@
"""Unit tests for workflow node execution conflict handling."""
from datetime import datetime
from unittest.mock import MagicMock, Mock
import psycopg2.errors
@@ -15,7 +16,6 @@ from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import NodeType
from libs.datetime_utils import naive_utc_now
from models import Account, WorkflowNodeExecutionTriggeredFrom
@@ -74,7 +74,7 @@ class TestWorkflowNodeExecutionConflictHandling:
title="Test Node",
index=1,
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=naive_utc_now(),
created_at=datetime.utcnow(),
)
original_id = execution.id
@@ -112,7 +112,7 @@ class TestWorkflowNodeExecutionConflictHandling:
title="Test Node",
index=1,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_at=naive_utc_now(),
created_at=datetime.utcnow(),
)
# Save should update existing record
@@ -157,7 +157,7 @@ class TestWorkflowNodeExecutionConflictHandling:
title="Test Node",
index=1,
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=naive_utc_now(),
created_at=datetime.utcnow(),
)
# Save should raise IntegrityError after max retries
@@ -199,7 +199,7 @@ class TestWorkflowNodeExecutionConflictHandling:
title="Test Node",
index=1,
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=naive_utc_now(),
created_at=datetime.utcnow(),
)
# Save should raise error immediately
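Since `datetime.utcnow()` is deprecated on newer Python versions, a tiny sketch of the naive-UTC helper these tests previously used (implementation assumed from the equivalent expression used elsewhere in this diff):

from datetime import UTC, datetime

def naive_utc_now() -> datetime:
    # Assumed equivalent of libs.datetime_utils.naive_utc_now: a tz-naive
    # UTC timestamp, i.e. datetime.now(UTC).replace(tzinfo=None).
    return datetime.now(UTC).replace(tzinfo=None)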

View File

@@ -1,120 +0,0 @@
"""Tests for graph engine event handlers."""
from __future__ import annotations
from datetime import datetime
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
class _StubEdgeProcessor:
"""Minimal edge processor stub for tests."""
class _StubErrorHandler:
"""Minimal error handler stub for tests."""
class _StubNode:
"""Simple node stub exposing the attributes needed by the state manager."""
def __init__(self, node_id: str) -> None:
self.id = node_id
self.state = NodeState.UNKNOWN
self.title = "Stub Node"
self.execution_type = NodeExecutionType.EXECUTABLE
self.error_strategy = None
self.retry_config = RetryConfig()
self.retry = False
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
"""Construct an EventHandler with in-memory dependencies for testing."""
node = _StubNode(node_id)
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
variable_pool = VariablePool()
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
graph_execution = GraphExecution(workflow_id="test-workflow")
event_manager = EventManager()
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
handler = EventHandler(
graph=graph,
graph_runtime_state=runtime_state,
graph_execution=graph_execution,
response_coordinator=response_coordinator,
event_collector=event_manager,
edge_processor=_StubEdgeProcessor(),
state_manager=state_manager,
error_handler=_StubErrorHandler(),
)
return handler, event_manager, graph_execution
def test_retry_does_not_emit_additional_start_event() -> None:
"""Ensure retry attempts do not produce duplicate start events."""
node_id = "test-node"
handler, event_manager, graph_execution = _build_event_handler(node_id)
execution_id = "exec-1"
node_type = NodeType.CODE
start_time = datetime.utcnow()
start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(start_event)
retry_event = NodeRunRetryEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
error="boom",
retry_index=1,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="boom",
error_type="TestError",
),
)
handler.dispatch(retry_event)
# Simulate the node starting execution again after retry
second_start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(second_start_event)
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
node_execution = graph_execution.get_or_create_node_execution(node_id)
assert node_execution.retry_count == 1

View File

@@ -10,18 +10,11 @@ import time
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
# Import the test framework from the new module
from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
@@ -728,39 +721,3 @@ def test_event_sequence_validation_with_table_tests():
else:
assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
def test_graph_run_emits_partial_success_when_node_failure_recovered():
runner = TableTestRunner()
fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
fixture_data=fixture_data,
query="hello",
use_mock_factory=True,
mock_config=mock_config,
)
llm_node = graph.nodes["llm"]
base_node_data = llm_node.get_base_node_data()
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
engine = GraphEngine(
workflow_id="test_workflow",
graph=graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
)
events = list(engine.run())
assert isinstance(events[-1], GraphRunPartialSucceededEvent)
partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
assert partial_event.exceptions_count == 1
assert partial_event.outputs.get("answer") == "fallback response"
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)

View File

@@ -0,0 +1,65 @@
import pytest
pytest.skip(
"Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
allow_module_level=True,
)
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11
def test_retry_failed():
"""retry failed with success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8

View File

@@ -1,8 +1,8 @@
import urllib.parse
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
@@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
({}, None, True),
],
)
@patch("httpx.post")
@patch("requests.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
):
@@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
),
],
)
@patch("httpx.get")
@patch("requests.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock()
user_response.json.return_value = user_data
@@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
assert user_info.name == user_data["name"]
assert user_info.email == expected_email
@patch("httpx.get")
@patch("requests.get")
def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = httpx.RequestError("Network error")
mock_get.side_effect = requests.exceptions.RequestException("Network error")
with pytest.raises(httpx.RequestError):
with pytest.raises(requests.exceptions.RequestException):
oauth.get_raw_user_info("test_token")
@@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({}, None, True),
],
)
@patch("httpx.post")
@patch("requests.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
):
@@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
],
)
@patch("httpx.get")
@patch("requests.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data
mock_get.return_value = mock_response
@@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
@pytest.mark.parametrize(
"exception_type",
[
httpx.HTTPError,
httpx.ConnectError,
httpx.TimeoutException,
requests.exceptions.HTTPError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
],
)
@patch("httpx.get")
@patch("requests.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error")

View File

@@ -1,83 +0,0 @@
import importlib
import types
import pytest
from models.model import Message
@pytest.fixture(autouse=True)
def patch_file_helpers(monkeypatch: pytest.MonkeyPatch):
"""
Patch file_helpers.get_signed_file_url to a deterministic stub.
"""
model_module = importlib.import_module("models.model")
dummy = types.SimpleNamespace(get_signed_file_url=lambda fid: f"https://signed.example/{fid}")
# Inject/override file_helpers on models.model
monkeypatch.setattr(model_module, "file_helpers", dummy, raising=False)
def _wrap_md(url: str) -> str:
"""
Wrap a raw URL into the markdown that re_sign_file_url_answer expects:
[link](<url>)
"""
return f"please click [file]({url}) to download."
def test_file_preview_valid_replaced():
"""
Valid file-preview URL must be re-signed:
- Extract upload_file_id correctly
- Replace the original URL with the signed URL
"""
upload_id = "abc-123"
url = f"/files/{upload_id}/file-preview?timestamp=111&nonce=222&sign=333"
msg = Message(answer=_wrap_md(url))
out = msg.re_sign_file_url_answer
assert f"https://signed.example/{upload_id}" in out
assert url not in out
def test_file_preview_misspelled_not_replaced():
"""
Misspelled endpoint 'file-previe?timestamp=' should NOT be rewritten.
"""
upload_id = "zzz-001"
# path deliberately misspelled: file-previe? (missing 'w')
# and we append &note=file-preview to trick the old `"file-preview" in url` check.
url = f"/files/{upload_id}/file-previe?timestamp=111&nonce=222&sign=333&note=file-preview"
original = _wrap_md(url)
msg = Message(answer=original)
out = msg.re_sign_file_url_answer
# Expect NO replacement: the misspelled file-previe URL must not be rewritten
assert out == original
def test_image_preview_valid_replaced():
"""
Valid image-preview URL must be re-signed.
"""
upload_id = "img-789"
url = f"/files/{upload_id}/image-preview?timestamp=123&nonce=456&sign=789"
msg = Message(answer=_wrap_md(url))
out = msg.re_sign_file_url_answer
assert f"https://signed.example/{upload_id}" in out
assert url not in out
def test_image_preview_misspelled_not_replaced():
"""
Misspelled endpoint 'image-previe?timestamp=' should NOT be rewritten.
"""
upload_id = "img-err-42"
url = f"/files/{upload_id}/image-previe?timestamp=1&nonce=2&sign=3&note=image-preview"
original = _wrap_md(url)
msg = Message(answer=original)
out = msg.re_sign_file_url_answer
# Expect NO replacement: the misspelled image-previe URL must not be rewritten
assert out == original
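The deleted tests above exercise Message.re_sign_file_url_answer. A minimal sketch of the behaviour they assert — a strict match on the real preview endpoints rather than the old `"file-preview" in url` check; the pattern itself is an assumption, not the code removed from models.model:

import re

# Only real /files/<id>/file-preview?... and image-preview?... URLs are re-signed;
# a misspelled "file-previe?" endpoint falls through untouched.
_PREVIEW_URL = re.compile(r"/files/(?P<upload_file_id>[\w-]+)/(?:file|image)-preview\?[^)\s]*")

def re_sign_answer(answer: str, get_signed_file_url) -> str:
    return _PREVIEW_URL.sub(lambda m: get_signed_file_url(m.group("upload_file_id")), answer)

# re_sign_answer("[file](/files/abc-123/file-preview?timestamp=1&nonce=2&sign=3)",
#                lambda fid: f"https://signed.example/{fid}")
# -> "[file](https://signed.example/abc-123)"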

View File

@@ -6,8 +6,8 @@ import json
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
import httpx
import pytest
import requests
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService
@@ -26,7 +26,7 @@ class TestAuthIntegration:
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage"""
@@ -47,7 +47,7 @@ class TestAuthIntegration:
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response()
@@ -97,7 +97,7 @@ class TestAuthIntegration:
assert "another_secret" not in factory_str
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety"""
@@ -142,31 +142,31 @@ class TestAuthIntegration:
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling"""
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
mock_http.return_value = mock_response
# PT012: Split into single statement for pytest.raises
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((httpx.HTTPError, Exception)):
with pytest.raises((requests.exceptions.HTTPError, Exception)):
factory.validate_credentials()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures"""
mock_http.side_effect = httpx.RequestError("Network timeout")
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
with pytest.raises(httpx.RequestError):
with pytest.raises(requests.exceptions.RequestException):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_session.commit.assert_not_called()
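The integration tests above assert a strict ordering — validation against the provider API, then token encryption, then the session add/commit — and that a network failure leaves commit uncalled. A minimal sketch of that ordering with injected dependencies (the real ApiKeyAuthService differs in detail):

def create_provider_auth_flow(validate, encrypt, session, make_binding):
    """Order asserted by the tests above: validate -> encrypt -> add -> commit."""
    validate()                        # mocked requests.post; raises on HTTP/network errors
    token = encrypt()                 # encrypter.encrypt_token in the patched service
    session.add(make_binding(token))
    session.commit()                  # must never run when validate() raised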

View File

@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from services.auth.firecrawl.firecrawl import FirecrawlAuth
@@ -64,7 +64,7 @@ class TestFirecrawlAuth:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@@ -95,7 +95,7 @@ class TestFirecrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@@ -115,7 +115,7 @@ class TestFirecrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@@ -134,13 +134,13 @@ class TestFirecrawlAuth:
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
@@ -162,7 +162,7 @@ class TestFirecrawlAuth:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
@@ -179,12 +179,12 @@ class TestFirecrawlAuth:
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.requests.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
with pytest.raises(httpx.TimeoutException) as exc_info:
with pytest.raises(requests.Timeout) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message
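The custom-base-URL test above only pins the resulting endpoint, https://custom.firecrawl.dev/v1/crawl; one plausible construction, offered as an assumption about the provider code rather than the code itself:

def firecrawl_validation_url(base_url: str) -> str:
    # hypothetical helper; only the resulting URL is asserted by the test above
    return f"{base_url.rstrip('/')}/v1/crawl"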

View File

@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from services.auth.jina.jina import JinaAuth
@@ -35,7 +35,7 @@ class TestJinaAuth:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation"""
mock_response = MagicMock()
@@ -53,7 +53,7 @@ class TestJinaAuth:
json={"url": "https://example.com"},
)
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
@@ -68,7 +68,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error"""
mock_response = MagicMock()
@@ -83,7 +83,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
@@ -98,7 +98,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
@@ -114,7 +114,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response"""
mock_response = MagicMock()
@@ -130,15 +130,15 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.requests.post")
def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors"""
mock_post.side_effect = httpx.ConnectError("Network error")
mock_post.side_effect = requests.ConnectionError("Network error")
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(httpx.ConnectError):
with pytest.raises(requests.ConnectionError):
auth.validate_credentials()
def test_should_not_expose_api_key_in_error_messages(self):
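The Jina assertions above fix an exact error-message format; a sketch of the branching they imply (the real services.auth.jina.jina code and its JSON handling may differ):

def raise_for_jina_status(status_code: int, error: str | None, text: str | None) -> None:
    # Hypothetical reconstruction from the asserted messages only.
    if status_code in (402, 409, 500) and error:
        raise Exception(f"Failed to authorize. Status code: {status_code}. Error: {error}")
    if text:
        raise Exception(f"Failed to authorize. Status code: {status_code}. Error: {text}")
    raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {status_code}")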

View File

@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from services.auth.watercrawl.watercrawl import WatercrawlAuth
@@ -64,7 +64,7 @@ class TestWatercrawlAuth:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@@ -87,7 +87,7 @@ class TestWatercrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@@ -107,7 +107,7 @@ class TestWatercrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@@ -126,13 +126,13 @@ class TestWatercrawlAuth:
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message)
@@ -154,7 +154,7 @@ class TestWatercrawlAuth:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
@@ -179,7 +179,7 @@ class TestWatercrawlAuth:
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock()
@@ -193,12 +193,12 @@ class TestWatercrawlAuth:
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.requests.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
with pytest.raises(httpx.TimeoutException) as exc_info:
with pytest.raises(requests.Timeout) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message
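The parametrized base URLs above all collapse to the same endpoint because, as the test name says, the auth class builds the URL with urllib.parse.urljoin; for example:

from urllib.parse import urljoin

# Both parametrized bases resolve to the same crawl-requests endpoint.
urljoin("https://app.watercrawl.dev//", "/api/v1/core/crawl-requests/")
# -> "https://app.watercrawl.dev/api/v1/core/crawl-requests/"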

Some files were not shown because too many files have changed in this diff.