mirror of
https://github.com/langgenius/dify.git
synced 2026-04-06 04:33:13 +08:00
Compare commits
2 Commits
deploy/rag
...
fix/templa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f63d3aa45 | ||
|
|
2c9c246052 |
@@ -1,16 +1,15 @@
|
||||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@@ -8,7 +8,8 @@ on:
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
- "deploy/rag-dev"
|
||||
- "feat/rag-2"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
4
.github/workflows/deploy-dev.yml
vendored
4
.github/workflows/deploy-dev.yml
vendored
@@ -4,7 +4,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/dev"
|
||||
- "deploy/rag-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -230,8 +230,4 @@ api/.env.backup
|
||||
|
||||
# Benchmark
|
||||
scripts/stress-test/setup/config/
|
||||
scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
scripts/stress-test/reports/
|
||||
5
Makefile
5
Makefile
@@ -61,9 +61,8 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format, check with fixes, and import linter..."
|
||||
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@uv run --directory api --dev lint-imports
|
||||
@echo "🔧 Running ruff format and check with fixes..."
|
||||
@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
|
||||
@@ -304,8 +304,6 @@ BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
|
||||
@@ -30,7 +30,6 @@ select = [
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"T201", # print-found
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
@@ -92,18 +91,11 @@ ignore = [
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
import logging
|
||||
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log(message: str):
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
print("gRPC patched with gevent.", flush=True) # noqa: T201
|
||||
_log("gRPC patched with gevent.")
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
|
||||
_log("psycopg2 patched with gevent.")
|
||||
|
||||
|
||||
from app import app, celery
|
||||
|
||||
371
api/commands.py
371
api/commands.py
@@ -10,7 +10,6 @@ from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@@ -62,30 +61,31 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@@ -100,21 +100,22 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
@@ -138,24 +139,25 @@ def reset_encrypt_key_pair():
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@@ -180,15 +182,14 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@@ -202,27 +203,26 @@ def migrate_annotation_vector_database():
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = db.session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
@@ -739,18 +739,18 @@ where sites.id is null limit 1000"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
logger.info("App %s not found", app_id)
|
||||
print(f"App {app_id} not found")
|
||||
continue
|
||||
|
||||
tenant = app.tenant
|
||||
if tenant:
|
||||
accounts = tenant.get_accounts()
|
||||
if not accounts:
|
||||
logger.info("Fix failed for app %s", app.id)
|
||||
print(f"Fix failed for app {app.id}")
|
||||
continue
|
||||
|
||||
account = accounts[0]
|
||||
logger.info("Fixing missing site for app %s", app.id)
|
||||
print(f"Fixing missing site for app {app.id}")
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
@@ -1448,52 +1448,41 @@ def transform_datasource_credentials():
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
continue
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
@@ -1506,48 +1495,37 @@ def transform_datasource_credentials():
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
continue
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
@@ -1560,45 +1538,36 @@ def transform_datasource_credentials():
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
print(jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
continue
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
@@ -18,18 +18,3 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN: str = Field(
|
||||
description="Token for uploading knowledge pipeline template.",
|
||||
default="",
|
||||
)
|
||||
|
||||
KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT: str = Field(
|
||||
description="Knowledge pipeline template copyright.",
|
||||
default="Copyright 2023 Dify",
|
||||
)
|
||||
|
||||
KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY: str = Field(
|
||||
description="Knowledge pipeline template privacy policy.",
|
||||
default="https://dify.ai",
|
||||
)
|
||||
|
||||
@@ -41,13 +41,3 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
|
||||
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
|
||||
default="DEFAULT_ANALYZER",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
@@ -37,15 +37,3 @@ class OceanBaseVectorConfig(BaseSettings):
|
||||
"with older versions",
|
||||
default=False,
|
||||
)
|
||||
|
||||
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
|
||||
description=(
|
||||
"Fulltext parser to use for text indexing. "
|
||||
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
|
||||
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
|
||||
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
|
||||
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
|
||||
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
|
||||
),
|
||||
default="ik",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,10 +30,10 @@ class NacosHttpClient:
|
||||
params = {}
|
||||
try:
|
||||
self._inject_auth_info(headers, params)
|
||||
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except httpx.RequestError as e:
|
||||
except requests.RequestException as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||
@@ -78,7 +78,7 @@ class NacosHttpClient:
|
||||
params = {"username": self.username, "password": self.password}
|
||||
url = "http://" + self.server + "/nacos/v1/auth/login"
|
||||
try:
|
||||
resp = httpx.request("POST", url, headers=None, params=params)
|
||||
resp = requests.request("POST", url, headers=None, params=params)
|
||||
resp.raise_for_status()
|
||||
response_data = resp.json()
|
||||
self.token = response_data.get("accessToken")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz # pip install pytz
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
@@ -71,7 +70,7 @@ class CompletionConversationApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
query = db.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
@@ -237,7 +236,7 @@ class ChatConversationApi(Resource):
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
|
||||
@@ -62,9 +62,6 @@ class ChatMessageListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model):
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
|
||||
@@ -50,9 +50,8 @@ class DailyMessageStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -188,9 +187,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -261,9 +259,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -343,9 +340,8 @@ FROM
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -430,9 +426,8 @@ LEFT JOIN
|
||||
message_feedbacks mf
|
||||
ON mf.message_id=m.id AND mf.rating='like'
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
m.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -507,9 +502,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -582,9 +576,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except httpx.HTTPStatusError as e:
|
||||
except requests.HTTPError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except httpx.HTTPStatusError as e:
|
||||
except requests.HTTPError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
@@ -101,10 +101,8 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except httpx.RequestError as e:
|
||||
error_text = str(e)
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
error_text = e.response.text
|
||||
except requests.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
|
||||
@@ -782,6 +782,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
@@ -808,7 +809,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
@@ -838,6 +838,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
@@ -862,7 +863,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
||||
@@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
@@ -212,13 +211,13 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
elif sort == "created_at":
|
||||
|
||||
@@ -14,10 +14,7 @@ from controllers.console.wraps import (
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
PipelineBuiltInTemplateEntity,
|
||||
PipelineTemplateInfoEntity,
|
||||
)
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -29,6 +26,12 @@ def _validate_name(name):
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -143,186 +146,6 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class PipelineTemplateInstallApi(Resource):
|
||||
"""API endpoint for installing built-in pipeline templates"""
|
||||
|
||||
def post(self):
|
||||
"""
|
||||
Install a built-in pipeline template
|
||||
|
||||
Args:
|
||||
template_id: The template ID from URL parameter
|
||||
|
||||
Returns:
|
||||
Success response or error with appropriate HTTP status
|
||||
"""
|
||||
try:
|
||||
# Extract and validate Bearer token
|
||||
auth_token = self._extract_bearer_token()
|
||||
|
||||
# Parse and validate request parameters
|
||||
template_args = self._parse_template_args()
|
||||
|
||||
# Process uploaded template file
|
||||
file_content = self._process_template_file()
|
||||
|
||||
# Create template entity
|
||||
pipeline_built_in_template_entity = PipelineBuiltInTemplateEntity(**template_args)
|
||||
|
||||
# Install the template
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.install_built_in_pipeline_template(
|
||||
pipeline_built_in_template_entity, file_content, auth_token
|
||||
)
|
||||
|
||||
return {"result": "success", "message": "Template installed successfully"}, 200
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Validation error in template installation")
|
||||
return {"error": str(e)}, 400
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error in template installation")
|
||||
return {"error": "An unexpected error occurred during template installation"}, 500
|
||||
|
||||
def _extract_bearer_token(self) -> str:
|
||||
"""
|
||||
Extract and validate Bearer token from Authorization header
|
||||
|
||||
Returns:
|
||||
The extracted token string
|
||||
|
||||
Raises:
|
||||
ValueError: If token is missing or invalid
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization", "").strip()
|
||||
|
||||
if not auth_header:
|
||||
raise ValueError("Authorization header is required")
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise ValueError("Authorization header must start with 'Bearer '")
|
||||
|
||||
token_parts = auth_header.split(" ", 1)
|
||||
if len(token_parts) != 2:
|
||||
raise ValueError("Invalid Authorization header format")
|
||||
|
||||
auth_token = token_parts[1].strip()
|
||||
if not auth_token:
|
||||
raise ValueError("Bearer token cannot be empty")
|
||||
|
||||
return auth_token
|
||||
|
||||
def _parse_template_args(self) -> dict:
|
||||
"""
|
||||
Parse and validate template arguments from form data
|
||||
|
||||
Args:
|
||||
template_id: The template ID from URL
|
||||
|
||||
Returns:
|
||||
Dictionary of validated template arguments
|
||||
"""
|
||||
# Use reqparse for consistent parameter parsing
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument(
|
||||
"template_id",
|
||||
type=str,
|
||||
location="form",
|
||||
required=False,
|
||||
help="Template ID for updating existing template"
|
||||
)
|
||||
parser.add_argument(
|
||||
"language",
|
||||
type=str,
|
||||
location="form",
|
||||
required=True,
|
||||
default="en-US",
|
||||
choices=["en-US", "zh-CN", "ja-JP"],
|
||||
help="Template language code"
|
||||
)
|
||||
parser.add_argument(
|
||||
"name",
|
||||
type=str,
|
||||
location="form",
|
||||
required=True,
|
||||
default="New Pipeline Template",
|
||||
help="Template name (1-200 characters)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
location="form",
|
||||
required=False,
|
||||
default="",
|
||||
help="Template description (max 1000 characters)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Additional validation
|
||||
if args.get("name"):
|
||||
args["name"] = self._validate_name(args["name"])
|
||||
|
||||
if args.get("description") and len(args["description"]) > 1000:
|
||||
raise ValueError("Description must not exceed 1000 characters")
|
||||
|
||||
# Filter out None values
|
||||
return {k: v for k, v in args.items() if v is not None}
|
||||
|
||||
def _validate_name(self, name: str) -> str:
|
||||
"""
|
||||
Validate template name
|
||||
|
||||
Args:
|
||||
name: Template name to validate
|
||||
|
||||
Returns:
|
||||
Validated and trimmed name
|
||||
|
||||
Raises:
|
||||
ValueError: If name is invalid
|
||||
"""
|
||||
name = name.strip()
|
||||
if not name or len(name) < 1 or len(name) > 200:
|
||||
raise ValueError("Template name must be between 1 and 200 characters")
|
||||
return name
|
||||
|
||||
def _process_template_file(self) -> str:
|
||||
"""
|
||||
Process and validate uploaded template file
|
||||
|
||||
Returns:
|
||||
File content as string
|
||||
|
||||
Raises:
|
||||
ValueError: If file is missing or invalid
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
raise ValueError("Template file is required")
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
# Validate file
|
||||
if not file or not file.filename:
|
||||
raise ValueError("No file selected")
|
||||
|
||||
filename = file.filename.strip()
|
||||
if not filename:
|
||||
raise ValueError("File name cannot be empty")
|
||||
|
||||
# Check file extension
|
||||
if not filename.lower().endswith(".pipeline"):
|
||||
raise ValueError("Template file must be a pipeline file (.pipeline)")
|
||||
|
||||
try:
|
||||
file_content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError("Template file must be valid UTF-8 text")
|
||||
|
||||
return file_content
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
@@ -339,7 +162,3 @@ api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateInstallApi,
|
||||
"/rag/pipeline/built-in/templates/install",
|
||||
)
|
||||
@@ -118,14 +118,12 @@ class RagPipelineExportApi(Resource):
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=str, default="false", location="args")
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=args["include_secret"] == "true"
|
||||
)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from packaging import version
|
||||
|
||||
@@ -57,11 +57,7 @@ class VersionApi(Resource):
|
||||
return result
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
timeout=httpx.Timeout(connect=3, read=10),
|
||||
)
|
||||
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args["current_version"]
|
||||
|
||||
@@ -79,12 +79,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -427,9 +427,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@@ -468,7 +465,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
@@ -563,7 +559,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
||||
@@ -86,12 +86,29 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -51,12 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -120,81 +119,15 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph
|
||||
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
(either single iteration or single loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
|
||||
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool, graph_runtime_state)
|
||||
|
||||
Raises:
|
||||
ValueError: If neither single_iteration_run nor single_loop_run is specified
|
||||
"""
|
||||
# Create initial runtime state with variable pool containing environment variables
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
),
|
||||
start_at=time.time(),
|
||||
)
|
||||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
|
||||
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get graph and variable pool for single node execution (iteration or loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
node_id: The node ID to execute
|
||||
user_inputs: User inputs for the node
|
||||
graph_runtime_state: The graph runtime state
|
||||
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
|
||||
node_type_label: Label for error messages ('iteration' or 'loop')
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool)
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
@@ -212,22 +145,18 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in the specified node type
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
@@ -261,26 +190,30 @@ class WorkflowBasedAppRunner:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
target_node_config = None
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
target_node_config = node
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
if not iteration_node_config:
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
@@ -301,44 +234,120 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in loop
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in loop
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
loop_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
loop_node_config = node
|
||||
break
|
||||
|
||||
if not loop_node_config:
|
||||
raise ValueError("loop node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
|
||||
node_version = loop_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=loop_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
load_into_variable_pool(
|
||||
self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
||||
388
api/core/datasource/utils/parser.py
Normal file
388
api/core/datasource/utils/parser.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]: # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {} # pyright: ignore[reportIndexIssue]
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = { # pyright: ignore[reportIndexIssue]
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # pyright: ignore[reportIndexIssue]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition # pyright: ignore[reportIndexIssue, reportArgumentType]
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
17
api/core/datasource/utils/text_processing_utils.py
Normal file
17
api/core/datasource/utils/text_processing_utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
||||
9
api/core/datasource/utils/uuid_utils.py
Normal file
9
api/core/datasource/utils/uuid_utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
43
api/core/datasource/utils/workflow_configuration_sync.py
Normal file
43
api/core/datasource/utils/workflow_configuration_sync.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
35
api/core/datasource/utils/yaml_utils.py
Normal file
35
api/core/datasource/utils/yaml_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
@@ -205,10 +205,16 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get custom provider record.
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(self._get_provider_names()),
|
||||
Provider.provider_name.in_(provider_names),
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
@@ -270,7 +276,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
stmt = select(ProviderCredential.id).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -318,7 +324,7 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
@@ -368,7 +374,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -381,7 +387,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
),
|
||||
@@ -417,16 +423,6 @@ class ProviderConfiguration(BaseModel):
|
||||
logger.warning("Error generating next credential name: %s", str(e))
|
||||
return "API KEY 1"
|
||||
|
||||
def _get_provider_names(self):
|
||||
"""
|
||||
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
return provider_names
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
@@ -505,7 +501,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
@@ -558,7 +554,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Find all load balancing configs that use this credential_id
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source,
|
||||
)
|
||||
@@ -595,7 +591,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
@@ -606,7 +602,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Check if this credential is used in load balancing configs
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
)
|
||||
@@ -628,7 +624,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# if this is the last credential, we need to delete the provider record
|
||||
count_stmt = select(func.count(ProviderCredential.id)).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
@@ -672,7 +668,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -741,7 +737,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -788,7 +784,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
@@ -864,7 +860,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1001,7 +997,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1046,7 +1042,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1056,7 +1052,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
)
|
||||
@@ -1079,7 +1075,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# if this is the last credential, we need to delete the custom model record
|
||||
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1119,7 +1115,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1161,7 +1157,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1208,9 +1204,15 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
@@ -1382,9 +1384,15 @@ class ProviderConfiguration(BaseModel):
|
||||
return
|
||||
|
||||
def _switch(s: Session):
|
||||
# get preferred provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
)
|
||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
@@ -65,13 +65,13 @@ class TraceClient:
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
response = httpx.head(self.endpoint, timeout=5)
|
||||
response = requests.head(self.endpoint, timeout=5)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
except requests.RequestException as e:
|
||||
logger.debug("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
|
||||
@@ -417,7 +417,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
logger.info("Weave login successful")
|
||||
print("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Weave API check failed: %s", str(e))
|
||||
|
||||
@@ -513,21 +513,6 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_names(provider_name: str) -> list[str]:
|
||||
"""
|
||||
provider_name: `openai` or `langgenius/openai/openai`
|
||||
return: [`openai`, `langgenius/openai/openai`]
|
||||
"""
|
||||
provider_names = [provider_name]
|
||||
model_provider_id = ModelProviderID(provider_name)
|
||||
if model_provider_id.is_langgenius():
|
||||
if "/" in provider_name:
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
else:
|
||||
provider_names.append(str(model_provider_id))
|
||||
return provider_names
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
@@ -540,10 +525,7 @@ class ProviderManager:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(
|
||||
ProviderCredential.tenant_id == tenant_id,
|
||||
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||
)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
@@ -572,7 +554,7 @@ class ProviderManager:
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -10,24 +9,11 @@ from pymochow import MochowClient # type: ignore
|
||||
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
|
||||
from pymochow.configuration import Configuration # type: ignore
|
||||
from pymochow.exception import ServerError # type: ignore
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import (
|
||||
Field,
|
||||
FilteringIndex,
|
||||
HNSWParams,
|
||||
InvertedIndex,
|
||||
InvertedIndexAnalyzer,
|
||||
InvertedIndexFieldAttribute,
|
||||
InvertedIndexParams,
|
||||
InvertedIndexParseMode,
|
||||
Schema,
|
||||
VectorIndex,
|
||||
) # type: ignore
|
||||
from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row # type: ignore
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field as VDBField
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@@ -36,8 +22,6 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaiduConfig(BaseModel):
|
||||
endpoint: str
|
||||
@@ -46,11 +30,9 @@ class BaiduConfig(BaseModel):
|
||||
api_key: str
|
||||
database: str
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "IP"
|
||||
metric_type: str = "L2"
|
||||
shard: int = 1
|
||||
replicas: int = 3
|
||||
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
|
||||
inverted_index_parser_mode: str = "COARSE_MODE"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -67,9 +49,13 @@ class BaiduConfig(BaseModel):
|
||||
|
||||
|
||||
class BaiduVector(BaseVector):
|
||||
vector_index: str = "vector_idx"
|
||||
filtering_index: str = "filtering_idx"
|
||||
inverted_index: str = "content_inverted_idx"
|
||||
field_id: str = "id"
|
||||
field_vector: str = "vector"
|
||||
field_text: str = "text"
|
||||
field_metadata: str = "metadata"
|
||||
field_app_id: str = "app_id"
|
||||
field_annotation_id: str = "annotation_id"
|
||||
index_vector: str = "vector_idx"
|
||||
|
||||
def __init__(self, collection_name: str, config: BaiduConfig):
|
||||
super().__init__(collection_name)
|
||||
@@ -88,6 +74,8 @@ class BaiduVector(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
|
||||
total_count = len(documents)
|
||||
batch_size = 1000
|
||||
|
||||
@@ -96,31 +84,29 @@ class BaiduVector(BaseVector):
|
||||
for start in range(0, total_count, batch_size):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
for i in range(start, end, 1):
|
||||
metadata = documents[i].metadata
|
||||
row = Row(
|
||||
id=metadata.get("doc_id", str(uuid.uuid4())),
|
||||
page_content=documents[i].page_content,
|
||||
metadata=metadata,
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=json.dumps(metadatas[i]),
|
||||
app_id=metadatas[i].get("app_id", ""),
|
||||
annotation_id=metadatas[i].get("annotation_id", ""),
|
||||
)
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
start_time = time.time()
|
||||
table.rebuild_index(self.index_vector)
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.vector_index)
|
||||
index = table.describe_index(self.index_vector)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
|
||||
if res and res.code == 0:
|
||||
return True
|
||||
return False
|
||||
@@ -129,73 +115,53 @@ class BaiduVector(BaseVector):
|
||||
if not ids:
|
||||
return
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")
|
||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
# Escape double quotes in value to prevent injection
|
||||
escaped_value = value.replace('"', '\\"')
|
||||
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')
|
||||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] IN({document_ids})'
|
||||
anns = AnnSearch(
|
||||
vector_field=VDBField.VECTOR,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
|
||||
filter=filter,
|
||||
)
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
filter=f"document_id IN ({document_ids})",
|
||||
)
|
||||
else:
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
res = self._db.table(self._collection_name).search(
|
||||
anns=anns,
|
||||
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
|
||||
retrieve_vector=False,
|
||||
projections=[self.field_id, self.field_text, self.field_metadata],
|
||||
retrieve_vector=True,
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# document ids filter
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] IN({document_ids})'
|
||||
|
||||
request = BM25SearchRequest(
|
||||
index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
|
||||
)
|
||||
res = self._db.table(self._collection_name).bm25_search(
|
||||
request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
# baidu vector database doesn't support bm25 search on current version
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res, score_threshold) -> list[Document]:
|
||||
docs = []
|
||||
for row in res.rows:
|
||||
row_data = row.get("row", {})
|
||||
meta = row_data.get(self.field_metadata)
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
score = row.get("score", 0.0)
|
||||
meta = row_data.get(VDBField.METADATA_KEY, {})
|
||||
|
||||
# Handle both JSON string and dict formats for backward compatibility
|
||||
if isinstance(meta, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
meta = json.loads(meta)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
meta = {}
|
||||
elif not isinstance(meta, dict):
|
||||
meta = {}
|
||||
|
||||
if score >= score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
|
||||
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
@@ -212,7 +178,7 @@ class BaiduVector(BaseVector):
|
||||
client = MochowClient(config)
|
||||
return client
|
||||
|
||||
def _init_database(self) -> Database:
|
||||
def _init_database(self):
|
||||
exists = False
|
||||
for db in self._client.list_databases():
|
||||
if db.database_name == self._client_config.database:
|
||||
@@ -226,10 +192,10 @@ class BaiduVector(BaseVector):
|
||||
self._client.create_database(database_name=self._client_config.database)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.DB_ALREADY_EXIST:
|
||||
return self._client.database(self._client_config.database)
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
return self._client.database(self._client_config.database)
|
||||
return
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
@@ -266,7 +232,7 @@ class BaiduVector(BaseVector):
|
||||
fields = []
|
||||
fields.append(
|
||||
Field(
|
||||
VDBField.PRIMARY_KEY,
|
||||
self.field_id,
|
||||
FieldType.STRING,
|
||||
primary_key=True,
|
||||
partition_key=True,
|
||||
@@ -274,57 +240,24 @@ class BaiduVector(BaseVector):
|
||||
not_null=True,
|
||||
)
|
||||
)
|
||||
fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
|
||||
fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
|
||||
fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
|
||||
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
|
||||
fields.append(Field(self.field_app_id, FieldType.STRING))
|
||||
fields.append(Field(self.field_annotation_id, FieldType.STRING))
|
||||
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
|
||||
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
|
||||
|
||||
# Construct vector index params
|
||||
indexes = []
|
||||
indexes.append(
|
||||
VectorIndex(
|
||||
index_name=self.vector_index,
|
||||
index_name="vector_idx",
|
||||
index_type=index_type,
|
||||
field=VDBField.VECTOR,
|
||||
field="vector",
|
||||
metric_type=metric_type,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
)
|
||||
)
|
||||
|
||||
# Filtering index
|
||||
indexes.append(
|
||||
FilteringIndex(
|
||||
index_name=self.filtering_index,
|
||||
fields=[VDBField.METADATA_KEY],
|
||||
)
|
||||
)
|
||||
|
||||
# Get analyzer and parse_mode from config
|
||||
analyzer = getattr(
|
||||
InvertedIndexAnalyzer,
|
||||
self._client_config.inverted_index_analyzer,
|
||||
InvertedIndexAnalyzer.DEFAULT_ANALYZER,
|
||||
)
|
||||
|
||||
parse_mode = getattr(
|
||||
InvertedIndexParseMode,
|
||||
self._client_config.inverted_index_parser_mode,
|
||||
InvertedIndexParseMode.COARSE_MODE,
|
||||
)
|
||||
|
||||
# Inverted index
|
||||
indexes.append(
|
||||
InvertedIndex(
|
||||
index_name=self.inverted_index,
|
||||
fields=[VDBField.CONTENT_KEY],
|
||||
params=InvertedIndexParams(
|
||||
analyzer=analyzer,
|
||||
parse_mode=parse_mode,
|
||||
case_sensitive=True,
|
||||
),
|
||||
field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
|
||||
)
|
||||
)
|
||||
|
||||
# Create table
|
||||
self._db.create_table(
|
||||
table_name=self._collection_name,
|
||||
@@ -335,15 +268,11 @@ class BaiduVector(BaseVector):
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
timeout = 300 # 5 minutes timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
@@ -367,7 +296,5 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
|
||||
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
|
||||
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import math
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pyobvector import VECTOR, ObVecClient, l2_distance # type: ignore
|
||||
from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore
|
||||
from sqlalchemy import JSON, Column, String
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
|
||||
@@ -117,39 +117,22 @@ class OceanBaseVector(BaseVector):
|
||||
columns=cols,
|
||||
vidxs=vidx_params,
|
||||
)
|
||||
logger.debug("DEBUG: Table '%s' created successfully", self._collection_name)
|
||||
|
||||
if self._hybrid_search_enabled:
|
||||
# Get parser from config or use default ik parser
|
||||
parser_name = dify_config.OCEANBASE_FULLTEXT_PARSER or "ik"
|
||||
|
||||
allowed_parsers = ["ngram", "beng", "space", "ngram2", "ik", "japanese_ftparser", "thai_ftparser"]
|
||||
if parser_name not in allowed_parsers:
|
||||
raise ValueError(
|
||||
f"Invalid OceanBase full-text parser: {parser_name}. "
|
||||
f"Allowed values are: {', '.join(allowed_parsers)}"
|
||||
try:
|
||||
if self._hybrid_search_enabled:
|
||||
self._client.create_fts_idx_with_fts_index_param(
|
||||
table_name=self._collection_name,
|
||||
fts_idx_param=FtsIndexParam(
|
||||
index_name="fulltext_index_for_col_text",
|
||||
field_names=["text"],
|
||||
parser_type=FtsParser.IK,
|
||||
),
|
||||
)
|
||||
logger.debug("Hybrid search is enabled, parser_name='%s'", parser_name)
|
||||
logger.debug(
|
||||
"About to create fulltext index for collection '%s' using parser '%s'",
|
||||
self._collection_name,
|
||||
parser_name,
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
|
||||
+ "to support fulltext index and vector index in the same table",
|
||||
e,
|
||||
)
|
||||
try:
|
||||
sql_command = f"""ALTER TABLE {self._collection_name}
|
||||
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER {parser_name}"""
|
||||
logger.debug("DEBUG: Executing SQL: %s", sql_command)
|
||||
self._client.perform_raw_text_sql(sql_command)
|
||||
logger.debug("DEBUG: Fulltext index created successfully for '%s'", self._collection_name)
|
||||
except Exception as e:
|
||||
logger.exception("Exception occurred while creating fulltext index")
|
||||
raise Exception(
|
||||
"Failed to add fulltext index to the target table, your OceanBase version must be "
|
||||
"4.3.5.1 or above to support fulltext index and vector index in the same table"
|
||||
) from e
|
||||
else:
|
||||
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
|
||||
|
||||
self._client.refresh_metadata([self._collection_name])
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@@ -246,7 +229,7 @@ class OceanBaseVector(BaseVector):
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON metadata: %s", metadata_str)
|
||||
print(f"Invalid JSON metadata: {metadata_str}")
|
||||
metadata = {}
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=_text, metadata=metadata))
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import array
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -20,8 +19,6 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
oracledb.defaults.fetch_lobs = False
|
||||
|
||||
|
||||
@@ -183,8 +180,8 @@ class OracleVector(BaseVector):
|
||||
value,
|
||||
)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to insert record %s into %s", value[0], self.table_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
conn.close()
|
||||
return pks
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
@@ -24,8 +23,6 @@ from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
@@ -190,8 +187,8 @@ class RelytVector(BaseVector):
|
||||
delete_condition = chunks_table.c.id.in_(ids)
|
||||
conn.execute(chunks_table.delete().where(delete_condition))
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Delete operation failed for collection %s", self._collection_name)
|
||||
except Exception as e:
|
||||
print("Delete operation failed:", str(e))
|
||||
return False
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
@@ -164,8 +164,8 @@ class TiDBVector(BaseVector):
|
||||
delete_condition = table.c.id.in_(ids)
|
||||
conn.execute(table.delete().where(delete_condition))
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Delete operation failed for collection %s", self._collection_name)
|
||||
except Exception as e:
|
||||
print("Delete operation failed:", str(e))
|
||||
return False
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
@@ -417,10 +417,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
if db_model is not None:
|
||||
offload_data = db_model.offload_data
|
||||
|
||||
else:
|
||||
db_model = self._to_db_model(domain_model)
|
||||
offload_data = db_model.offload_data
|
||||
offload_data = []
|
||||
|
||||
offload_data = db_model.offload_data
|
||||
if domain_model.inputs is not None:
|
||||
result = self._truncate_and_upload(
|
||||
domain_model.inputs,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from pathlib import Path
|
||||
@@ -9,8 +8,6 @@ from typing import Any, ClassVar, Optional
|
||||
class SchemaRegistry:
|
||||
"""Schema registry manages JSON schemas with version support"""
|
||||
|
||||
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
|
||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
@@ -86,7 +83,7 @@ class SchemaRegistry:
|
||||
self.metadata[uri] = metadata
|
||||
|
||||
except (OSError, json.JSONDecodeError) as e:
|
||||
self.logger.warning("Failed to load schema %s/%s: %s", version, schema_name, e)
|
||||
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
|
||||
|
||||
def get_schema(self, uri: str) -> Any | None:
|
||||
"""Retrieves a schema by URI with version support"""
|
||||
|
||||
@@ -396,10 +396,6 @@ class ApiTool(Tool):
|
||||
# assemble invoke message based on response type
|
||||
if parsed_response.is_json and isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
|
||||
# FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
|
||||
# We need never break the original flows
|
||||
yield self.create_text_message(response.text)
|
||||
else:
|
||||
# Convert to string if needed and create text message
|
||||
text_response = (
|
||||
|
||||
@@ -41,8 +41,7 @@ class GraphExecutionState(BaseModel):
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
@@ -104,8 +103,7 @@ class GraphExecution:
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
@@ -174,7 +172,6 @@ class GraphExecution:
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
@@ -198,7 +195,6 @@ class GraphExecution:
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
@@ -209,7 +205,3 @@ class GraphExecution:
|
||||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
|
||||
@@ -3,12 +3,11 @@ Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@@ -123,15 +122,13 @@ class EventHandler:
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event only for the first attempt; retries remain silent
|
||||
if is_initial_attempt:
|
||||
self._event_collector.collect(event)
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
@@ -164,7 +161,7 @@ class EventHandler:
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
self._store_node_outputs(event)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
@@ -194,7 +191,7 @@ class EventHandler:
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
self._update_response_outputs(event)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
@@ -210,7 +207,6 @@ class EventHandler:
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
@@ -231,40 +227,10 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch/default-value, treat as completion
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
|
||||
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update response outputs if applicable
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Collect the exception event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
@@ -276,31 +242,21 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
# Finish the previous attempt before re-queuing the node
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Emit retry event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
# Re-queue node for execution
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in outputs.items():
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
|
||||
@@ -5,7 +5,6 @@ Unified event manager for collecting and emitting events.
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
@@ -52,23 +51,43 @@ class ReadWriteLock:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
@contextmanager
|
||||
def read_lock(self):
|
||||
def read_lock(self) -> "ReadLockContext":
|
||||
"""Return a context manager for read locking."""
|
||||
self.acquire_read()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_read()
|
||||
return ReadLockContext(self)
|
||||
|
||||
@contextmanager
|
||||
def write_lock(self):
|
||||
def write_lock(self) -> "WriteLockContext":
|
||||
"""Return a context manager for write locking."""
|
||||
self.acquire_write()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_write()
|
||||
return WriteLockContext(self)
|
||||
|
||||
|
||||
@final
|
||||
class ReadLockContext:
|
||||
"""Context manager for read locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "ReadLockContext":
|
||||
self._lock.acquire_read()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_read()
|
||||
|
||||
|
||||
@final
|
||||
class WriteLockContext:
|
||||
"""Context manager for write locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "WriteLockContext":
|
||||
self._lock.acquire_write()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_write()
|
||||
|
||||
|
||||
@final
|
||||
|
||||
@@ -23,7 +23,6 @@ from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
@@ -261,23 +260,12 @@ class GraphEngine:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
yield GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
||||
@@ -15,7 +15,6 @@ from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
@@ -128,13 +127,6 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.logger.warning("⚠️ Graph run partially succeeded")
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.logger.error("❌ Graph run failed: %s", event.error)
|
||||
if event.exceptions_count > 0:
|
||||
@@ -146,12 +138,6 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
# Node-level events
|
||||
# Retry before Started because Retry subclasses Started;
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self.retry_count += 1
|
||||
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
|
||||
self.logger.warning(" Previous error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.node_count += 1
|
||||
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
|
||||
@@ -181,6 +167,11 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
|
||||
self.logger.warning(" Error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self.retry_count += 1
|
||||
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
|
||||
self.logger.warning(" Previous error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
# Log stream chunks at debug level to avoid spam
|
||||
final_indicator = " (FINAL)" if event.is_final else ""
|
||||
|
||||
@@ -147,4 +147,4 @@ class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
self.logger.debug("Abort command sent to engine")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Failed to send abort command")
|
||||
self.logger.exception("Failed to send abort command: %s")
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import contextvars
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
@@ -21,7 +19,6 @@ from core.workflow.enums import (
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
@@ -37,7 +34,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
@@ -242,8 +238,6 @@ class IterationNode(Node):
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context_vars=contextvars.copy_context(),
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
@@ -286,29 +280,26 @@ class IterationNode(Node):
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Collect events instead of yielding them directly
|
||||
for event in self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs_temp,
|
||||
graph_engine=graph_engine,
|
||||
):
|
||||
events.append(event)
|
||||
# Collect events instead of yielding them directly
|
||||
for event in self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs_temp,
|
||||
graph_engine=graph_engine,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
@@ -381,16 +372,43 @@ class IterationNode(Node):
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
# init graph
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
# Create minimal GraphInitParams for static analysis
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="",
|
||||
app_id="",
|
||||
workflow_id="",
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from="",
|
||||
invoke_from="",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create minimal GraphRuntimeState for static analysis
|
||||
from core.workflow.entities import VariablePool
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(),
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create node factory for static analysis
|
||||
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=typed_node_data.start_node_id,
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
@@ -426,7 +444,9 @@ class IterationNode(Node):
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@@ -465,7 +485,7 @@ class IterationNode(Node):
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||
yield event
|
||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
result = variable_pool.get(self._node_data.output_selector)
|
||||
if result is None:
|
||||
outputs.append(None)
|
||||
|
||||
@@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel):
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
|
||||
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
@@ -128,13 +127,11 @@ class LoopNode(Node):
|
||||
try:
|
||||
reach_break_condition = False
|
||||
if break_conditions:
|
||||
with contextlib.suppress(ValueError):
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if reach_break_condition:
|
||||
loop_count = 0
|
||||
cost_tokens = 0
|
||||
@@ -298,11 +295,42 @@ class LoopNode(Node):
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
# init graph
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
# Create minimal GraphInitParams for static analysis
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="",
|
||||
app_id="",
|
||||
workflow_id="",
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from="",
|
||||
invoke_from="",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Get node configs from graph_config
|
||||
# Create minimal GraphRuntimeState for static analysis
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(),
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create node factory for static analysis
|
||||
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
loop_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=typed_node_data.start_node_id,
|
||||
)
|
||||
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
|
||||
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
@@ -343,35 +371,12 @@ class LoopNode(Node):
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
This method statically analyzes the graph configuration to find all nodes
|
||||
that are part of the specified loop, without creating actual node instances.
|
||||
|
||||
:param graph_config: the complete graph configuration
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
loop_node_ids.add(node_id)
|
||||
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
|
||||
@@ -33,7 +33,6 @@ file_fields = {
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"preview_url": fields.String,
|
||||
"source_url": fields.String,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,32 +1,10 @@
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from gevent import events as gevent_events
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# NOTE(QuantumGhost): here we cannot use post_fork to patch gRPC, as
|
||||
# grpc_gevent.init_gevent must be called after patching stdlib.
|
||||
# Gunicorn calls `post_init` before applying monkey patch.
|
||||
# Use `post_init` to setup gRPC gevent support would cause deadlock and
|
||||
# some other weird issues.
|
||||
#
|
||||
# ref:
|
||||
# - https://github.com/grpc/grpc/blob/62533ea13879d6ee95c6fda11ec0826ca822c9dd/src/python/grpcio/grpc/experimental/gevent.py
|
||||
# - https://github.com/gevent/gevent/issues/2060#issuecomment-3016768668
|
||||
# - https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py#L607-L613
|
||||
|
||||
|
||||
def post_patch(event):
|
||||
# this function is only called for gevent worker.
|
||||
# from gevent docs (https://www.gevent.org/api/gevent.monkey.html):
|
||||
# You can also subscribe to the events to provide additional patching beyond what gevent distributes, either for
|
||||
# additional standard library modules, or for third-party packages. The suggested time to do this patching is in
|
||||
# the subscriber for gevent.events.GeventDidPatchBuiltinModulesEvent.
|
||||
if not isinstance(event, gevent_events.GeventDidPatchBuiltinModulesEvent):
|
||||
return
|
||||
def post_fork(server, worker):
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
print("gRPC patched with gevent.", flush=True) # noqa: T201
|
||||
server.log.info("gRPC patched with gevent.")
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
|
||||
|
||||
|
||||
gevent_events.subscribers.append(post_patch)
|
||||
server.log.info("psycopg2 patched with gevent.")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -58,7 +58,7 @@ class GitHubOAuth(OAuth):
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
@@ -70,11 +70,11 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response = requests.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = email_response.json()
|
||||
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
|
||||
|
||||
@@ -112,7 +112,7 @@ class GoogleOAuth(OAuth):
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
@@ -124,7 +124,7 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response = requests.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
|
||||
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
@@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
|
||||
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
|
||||
results.extend(response_json.get("results", []))
|
||||
@@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
||||
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
message = response_json.get("message", "unknown error")
|
||||
@@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
|
||||
response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
|
||||
response_json = response.json()
|
||||
if "object" in response_json and response_json["object"] == "user":
|
||||
user_type = response_json["type"]
|
||||
@@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
|
||||
results.extend(response_json.get("results", []))
|
||||
|
||||
@@ -47,7 +47,7 @@ def upgrade():
|
||||
sa.Column('plugin_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('auth_type', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('avatar_url', sa.Text(), nullable=True),
|
||||
sa.Column('avatar_url', sa.String(length=255), nullable=True),
|
||||
sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||
sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""remove-builtin-template-user
|
||||
|
||||
Revision ID: bf0bcbf45396
|
||||
Revises: 68519ad5cd18
|
||||
Create Date: 2025-09-25 16:50:32.245503
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'bf0bcbf45396'
|
||||
down_revision = '68519ad5cd18'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.drop_column('updated_by')
|
||||
batch_op.drop_column('created_by')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
|
||||
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -910,7 +910,7 @@ class AppDatasetJoin(Base):
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
@@ -931,7 +931,7 @@ class DatasetQuery(Base):
|
||||
source_app_id = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetKeywordTable(Base):
|
||||
@@ -1239,6 +1239,15 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
||||
language = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
|
||||
|
||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||
|
||||
@@ -1044,7 +1044,7 @@ class Message(Base):
|
||||
sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
elif "file-preview" in url:
|
||||
# get upload file id
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview\?timestamp="
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
@@ -1055,7 +1055,7 @@ class Message(Base):
|
||||
sign_url = file_helpers.get_signed_file_url(upload_file_id)
|
||||
elif "image-preview" in url:
|
||||
# image-preview is deprecated, use file-preview instead
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview\?timestamp="
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
@@ -1731,7 +1731,7 @@ class MessageChain(Base):
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
input = mapped_column(sa.Text, nullable=True)
|
||||
output = mapped_column(sa.Text, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAgentThought(Base):
|
||||
@@ -1769,7 +1769,7 @@ class MessageAgentThought(Base):
|
||||
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def files(self) -> list[Any]:
|
||||
@@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base):
|
||||
index_node_hash = mapped_column(sa.Text, nullable=True)
|
||||
retriever_from = mapped_column(sa.Text, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
|
||||
@@ -35,7 +35,7 @@ class DatasourceProvider(Base):
|
||||
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||
avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default")
|
||||
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
|
||||
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.9.0"
|
||||
version = "2.0.0-beta2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
"arize-phoenix-otel~=0.9.2",
|
||||
"authlib==1.6.4",
|
||||
"authlib==1.3.1",
|
||||
"azure-identity==1.16.1",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"boto3==1.35.99",
|
||||
@@ -20,7 +20,7 @@ dependencies = [
|
||||
"flask-migrate~=4.0.7",
|
||||
"flask-orjson~=2.0.0",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gevent~=24.11.1",
|
||||
"gmpy2~=2.2.1",
|
||||
"google-api-core==2.18.0",
|
||||
"google-api-python-client==2.90.0",
|
||||
@@ -169,7 +169,7 @@ dev = [
|
||||
"types-redis>=4.6.0.20241004",
|
||||
"celery-types>=0.23.0",
|
||||
"mypy~=1.17.1",
|
||||
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
|
||||
"locust>=2.40.4",
|
||||
"sseclient-py>=1.8.0",
|
||||
]
|
||||
|
||||
@@ -211,7 +211,7 @@ vdb = [
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvector==0.2.5",
|
||||
"pymilvus~=2.5.0",
|
||||
"pymochow==2.2.9",
|
||||
"pymochow==1.3.1",
|
||||
"pyobvector~=0.2.15",
|
||||
"qdrant-client==1.9.0",
|
||||
"tablestore==6.2.0",
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from typing import TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from configs import dify_config
|
||||
@@ -66,7 +65,7 @@ class AppService:
|
||||
return None
|
||||
|
||||
app_models = db.paginate(
|
||||
sa.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
db.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
page=args["page"],
|
||||
per_page=args["limit"],
|
||||
error_out=False,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
@@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
@@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
@@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
@@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
|
||||
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
|
||||
|
||||
def _get_request(self, url, headers):
|
||||
return httpx.get(url, headers=headers)
|
||||
return requests.get(url, headers=headers)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
||||
@@ -115,12 +115,12 @@ class DatasetService:
|
||||
# Check if permitted_dataset_ids is not empty to avoid WHERE false condition
|
||||
if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
|
||||
query = query.where(
|
||||
sa.or_(
|
||||
db.or_(
|
||||
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
|
||||
sa.and_(
|
||||
db.and_(
|
||||
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
|
||||
),
|
||||
sa.and_(
|
||||
db.and_(
|
||||
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
Dataset.id.in_(permitted_dataset_ids),
|
||||
),
|
||||
@@ -128,9 +128,9 @@ class DatasetService:
|
||||
)
|
||||
else:
|
||||
query = query.where(
|
||||
sa.or_(
|
||||
db.or_(
|
||||
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
|
||||
sa.and_(
|
||||
db.and_(
|
||||
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
|
||||
),
|
||||
)
|
||||
@@ -1879,7 +1879,7 @@ class DocumentService:
|
||||
# for notion_info in notion_info_list:
|
||||
# workspace_id = notion_info.workspace_id
|
||||
# data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
# sa.and_(
|
||||
# db.and_(
|
||||
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
# DataSourceOauthBinding.provider == "notion",
|
||||
# DataSourceOauthBinding.disabled == False,
|
||||
|
||||
@@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel):
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
|
||||
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
@@ -128,10 +128,3 @@ class KnowledgeConfiguration(BaseModel):
|
||||
if v is None:
|
||||
return ""
|
||||
return v
|
||||
|
||||
|
||||
class PipelineBuiltInTemplateEntity(BaseModel):
|
||||
template_id: str | None = None
|
||||
name: str
|
||||
description: str
|
||||
language: str
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
|
||||
class OperationService:
|
||||
@@ -12,7 +12,7 @@ class OperationService:
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
@@ -46,11 +46,7 @@ limit 1000"""
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
retrieval_model = i.retrieval_model
|
||||
logger.debug(
|
||||
"Processing dataset %s with retrieval model of type %s",
|
||||
record_id,
|
||||
type(retrieval_model),
|
||||
)
|
||||
print(type(retrieval_model))
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
@@ -471,7 +471,7 @@ class PluginMigration:
|
||||
total_failed_tenant = 0
|
||||
while True:
|
||||
# paginate
|
||||
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
|
||||
tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
|
||||
if tenants.items is None or len(tenants.items) == 0:
|
||||
break
|
||||
|
||||
|
||||
@@ -74,4 +74,5 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"chunk_structure": pipeline_template.chunk_structure,
|
||||
"export_data": pipeline_template.yaml_content,
|
||||
"graph": graph_data,
|
||||
"created_by": pipeline_template.created_user_name,
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ from datetime import UTC, datetime
|
||||
from typing import Any, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@@ -61,7 +60,6 @@ from models.dataset import ( # type: ignore
|
||||
Document,
|
||||
DocumentPipelineExecutionLog,
|
||||
Pipeline,
|
||||
PipelineBuiltInTemplate,
|
||||
PipelineCustomizedTemplate,
|
||||
PipelineRecommendedPlugin,
|
||||
)
|
||||
@@ -78,7 +76,6 @@ from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
KnowledgeConfiguration,
|
||||
PipelineBuiltInTemplateEntity,
|
||||
PipelineTemplateInfoEntity,
|
||||
)
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
@@ -1330,14 +1327,14 @@ class RagPipelineService:
|
||||
"""
|
||||
Retry error document
|
||||
"""
|
||||
document_pipeline_execution_log = (
|
||||
document_pipeline_excution_log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document.id)
|
||||
.first()
|
||||
)
|
||||
if not document_pipeline_execution_log:
|
||||
if not document_pipeline_excution_log:
|
||||
raise ValueError("Document pipeline execution log not found")
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
# convert to app config
|
||||
@@ -1349,10 +1346,10 @@ class RagPipelineService:
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": document_pipeline_execution_log.input_data,
|
||||
"start_node_id": document_pipeline_execution_log.datasource_node_id,
|
||||
"datasource_type": document_pipeline_execution_log.datasource_type,
|
||||
"datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
|
||||
"inputs": document_pipeline_excution_log.input_data,
|
||||
"start_node_id": document_pipeline_excution_log.datasource_node_id,
|
||||
"datasource_type": document_pipeline_excution_log.datasource_type,
|
||||
"datasource_info_list": [json.loads(document_pipeline_excution_log.datasource_info)],
|
||||
"original_document_id": document.id,
|
||||
},
|
||||
invoke_from=InvokeFrom.PUBLISHED,
|
||||
@@ -1384,8 +1381,8 @@ class RagPipelineService:
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
datasource_plugins = []
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("data", {}).get("type") == "datasource":
|
||||
datasource_node_data = datasource_node["data"]
|
||||
if datasource_node.get("type") == "datasource":
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
if not datasource_node_data:
|
||||
continue
|
||||
|
||||
@@ -1457,140 +1454,3 @@ class RagPipelineService:
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
return pipeline
|
||||
|
||||
def install_built_in_pipeline_template(
|
||||
self, args: PipelineBuiltInTemplateEntity, file_content: str, auth_token: str
|
||||
) -> None:
|
||||
"""
|
||||
Install built-in pipeline template
|
||||
|
||||
Args:
|
||||
args: Pipeline built-in template entity with template metadata
|
||||
file_content: YAML content of the pipeline template
|
||||
auth_token: Authentication token for authorization
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails or template processing errors occur
|
||||
"""
|
||||
# Validate authentication
|
||||
self._validate_auth_token(auth_token)
|
||||
|
||||
# Parse and validate template content
|
||||
pipeline_template_dsl = self._parse_template_content(file_content)
|
||||
|
||||
# Extract template metadata
|
||||
icon = self._extract_icon_metadata(pipeline_template_dsl)
|
||||
chunk_structure = self._extract_chunk_structure(pipeline_template_dsl)
|
||||
|
||||
# Prepare template data
|
||||
template_data = {
|
||||
"name": args.name,
|
||||
"description": args.description,
|
||||
"chunk_structure": chunk_structure,
|
||||
"icon": icon,
|
||||
"language": args.language,
|
||||
"yaml_content": file_content,
|
||||
}
|
||||
|
||||
# Use transaction for database operations
|
||||
try:
|
||||
if args.template_id:
|
||||
self._update_existing_template(args.template_id, template_data)
|
||||
else:
|
||||
self._create_new_template(template_data)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise ValueError(f"Failed to install pipeline template: {str(e)}")
|
||||
|
||||
def _validate_auth_token(self, auth_token: str) -> None:
|
||||
"""Validate the authentication token"""
|
||||
config_auth_token = dify_config.UPLOAD_KNOWLEDGE_PIPELINE_TEMPLATE_TOKEN
|
||||
if not config_auth_token:
|
||||
raise ValueError("Auth token configuration is required")
|
||||
if config_auth_token != auth_token:
|
||||
raise ValueError("Auth token is incorrect")
|
||||
|
||||
def _parse_template_content(self, file_content: str) -> dict:
|
||||
"""Parse and validate YAML template content"""
|
||||
try:
|
||||
pipeline_template_dsl = yaml.safe_load(file_content)
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML content: {str(e)}")
|
||||
|
||||
if not pipeline_template_dsl:
|
||||
raise ValueError("Pipeline template DSL is required")
|
||||
|
||||
return pipeline_template_dsl
|
||||
|
||||
def _extract_icon_metadata(self, pipeline_template_dsl: dict) -> dict:
|
||||
"""Extract icon metadata from template DSL"""
|
||||
rag_pipeline_info = pipeline_template_dsl.get("rag_pipeline", {})
|
||||
|
||||
return {
|
||||
"icon": rag_pipeline_info.get("icon", "📙"),
|
||||
"icon_type": rag_pipeline_info.get("icon_type", "emoji"),
|
||||
"icon_background": rag_pipeline_info.get("icon_background", "#FFEAD5"),
|
||||
"icon_url": rag_pipeline_info.get("icon_url"),
|
||||
}
|
||||
|
||||
def _extract_chunk_structure(self, pipeline_template_dsl: dict) -> str:
|
||||
"""Extract chunk structure from template DSL"""
|
||||
nodes = pipeline_template_dsl.get("workflow", {}).get("graph", {}).get("nodes", [])
|
||||
|
||||
# Use generator expression for efficiency
|
||||
chunk_structure = next(
|
||||
(
|
||||
node.get("data", {}).get("chunk_structure")
|
||||
for node in nodes
|
||||
if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_INDEX.value
|
||||
),
|
||||
None
|
||||
)
|
||||
|
||||
if not chunk_structure:
|
||||
raise ValueError("Chunk structure is required in template")
|
||||
|
||||
return chunk_structure
|
||||
|
||||
def _update_existing_template(self, template_id: str, template_data: dict) -> None:
|
||||
"""Update an existing pipeline template"""
|
||||
pipeline_built_in_template = (
|
||||
db.session.query(PipelineBuiltInTemplate)
|
||||
.filter(PipelineBuiltInTemplate.id == template_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pipeline_built_in_template:
|
||||
raise ValueError(f"Pipeline built-in template not found: {template_id}")
|
||||
|
||||
# Update template fields
|
||||
for key, value in template_data.items():
|
||||
setattr(pipeline_built_in_template, key, value)
|
||||
|
||||
db.session.add(pipeline_built_in_template)
|
||||
|
||||
def _create_new_template(self, template_data: dict) -> None:
|
||||
"""Create a new pipeline template"""
|
||||
# Get the next available position
|
||||
position = self._get_next_position(template_data["language"])
|
||||
|
||||
# Add additional fields for new template
|
||||
template_data.update({
|
||||
"position": position,
|
||||
"install_count": 0,
|
||||
"copyright": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_COPYRIGHT,
|
||||
"privacy_policy": dify_config.KNOWLEDGE_PIPELINE_TEMPLATE_PRIVACY_POLICY,
|
||||
})
|
||||
|
||||
new_template = PipelineBuiltInTemplate(**template_data)
|
||||
db.session.add(new_template)
|
||||
|
||||
def _get_next_position(self, language: str) -> int:
|
||||
"""Get the next available position for a template in the specified language"""
|
||||
max_position = (
|
||||
db.session.query(func.max(PipelineBuiltInTemplate.position))
|
||||
.filter(PipelineBuiltInTemplate.language == language)
|
||||
.scalar()
|
||||
)
|
||||
return (max_position or 0) + 1
|
||||
|
||||
@@ -685,24 +685,12 @@ class RagPipelineDslService:
|
||||
|
||||
workflow_dict = workflow.to_dict(include_secret=include_secret)
|
||||
for node in workflow_dict.get("graph", {}).get("nodes", []):
|
||||
node_data = node.get("data", {})
|
||||
if not node_data:
|
||||
continue
|
||||
data_type = node_data.get("type", "")
|
||||
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
dataset_ids = node_data.get("dataset_ids", [])
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
# filter credential id from tool node
|
||||
if not include_secret and data_type == NodeType.TOOL.value:
|
||||
node_data.pop("credential_id", None)
|
||||
# filter credential id from agent node
|
||||
if not include_secret and data_type == NodeType.AGENT.value:
|
||||
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
|
||||
tool.pop("credential_id", None)
|
||||
|
||||
export_data["workflow"] = workflow_dict
|
||||
dependencies = self._extract_dependencies_from_workflow(workflow)
|
||||
export_data["dependencies"] = [
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
@@ -18,8 +17,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import Knowledge
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPipelineTransformService:
|
||||
def transform_dataset(self, dataset_id: str):
|
||||
@@ -38,11 +35,11 @@ class RagPipelineTransformService:
|
||||
indexing_technique = dataset.indexing_technique
|
||||
|
||||
if not datasource_type and not indexing_technique:
|
||||
return self._transform_to_empty_pipeline(dataset)
|
||||
return self._transfrom_to_empty_pipeline(dataset)
|
||||
|
||||
doc_form = dataset.doc_form
|
||||
if not doc_form:
|
||||
return self._transform_to_empty_pipeline(dataset)
|
||||
return self._transfrom_to_empty_pipeline(dataset)
|
||||
retrieval_model = dataset.retrieval_model
|
||||
pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
|
||||
# deal dependencies
|
||||
@@ -260,10 +257,10 @@ class RagPipelineTransformService:
|
||||
if plugin_unique_identifier:
|
||||
need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
|
||||
if need_install_plugin_unique_identifiers:
|
||||
logger.debug("Installing missing pipeline plugins %s", need_install_plugin_unique_identifiers)
|
||||
print(need_install_plugin_unique_identifiers)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, need_install_plugin_unique_identifiers)
|
||||
|
||||
def _transform_to_empty_pipeline(self, dataset: Dataset):
|
||||
def _transfrom_to_empty_pipeline(self, dataset: Dataset):
|
||||
pipeline = Pipeline(
|
||||
tenant_id=dataset.tenant_id,
|
||||
name=dataset.name,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -19,7 +18,7 @@ class TagService:
|
||||
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
if keyword:
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@@ -262,14 +262,6 @@ class VariableTruncator:
|
||||
target_length = self._array_element_limit
|
||||
|
||||
for i, item in enumerate(value):
|
||||
# Dirty fix:
|
||||
# The output of `Start` node may contain list of `File` elements,
|
||||
# causing `AssertionError` while invoking `_truncate_json_primitives`.
|
||||
#
|
||||
# This check ensures that `list[File]` are handled separately
|
||||
if isinstance(item, File):
|
||||
truncated_value.append(item)
|
||||
continue
|
||||
if i >= target_length:
|
||||
return _PartResult(truncated_value, used_size, True)
|
||||
if i > 0:
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
@@ -216,7 +216,7 @@ class WebsiteService:
|
||||
@classmethod
|
||||
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
|
||||
if not request.options.crawl_sub_pages:
|
||||
response = httpx.get(
|
||||
response = requests.get(
|
||||
f"https://r.jina.ai/{request.url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
@@ -224,7 +224,7 @@ class WebsiteService:
|
||||
raise ValueError("Failed to crawl:")
|
||||
return {"status": "active", "data": response.json().get("data")}
|
||||
else:
|
||||
response = httpx.post(
|
||||
response = requests.post(
|
||||
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
|
||||
json={
|
||||
"url": request.url,
|
||||
@@ -287,7 +287,7 @@ class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
response = httpx.post(
|
||||
response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
@@ -303,7 +303,7 @@ class WebsiteService:
|
||||
}
|
||||
|
||||
if crawl_status_data["status"] == "completed":
|
||||
response = httpx.post(
|
||||
response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
||||
@@ -362,7 +362,7 @@ class WebsiteService:
|
||||
@classmethod
|
||||
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
|
||||
if not job_id:
|
||||
response = httpx.get(
|
||||
response = requests.get(
|
||||
f"https://r.jina.ai/{url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
@@ -371,7 +371,7 @@ class WebsiteService:
|
||||
return dict(response.json().get("data", {}))
|
||||
else:
|
||||
# Get crawl status first
|
||||
status_response = httpx.post(
|
||||
status_response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
@@ -381,7 +381,7 @@ class WebsiteService:
|
||||
raise ValueError("Crawl job is not completed")
|
||||
|
||||
# Get processed data
|
||||
data_response = httpx.post(
|
||||
data_response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
|
||||
|
||||
@@ -450,8 +450,7 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
# plugin does not require credentials, skip
|
||||
return
|
||||
raise ValueError("No default credential found")
|
||||
|
||||
# Check credential policy compliance using the default credential ID
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -52,7 +51,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
sa.and_(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == document.tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Integration tests for ChatMessageApi permission verification."""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import completion as completion_api
|
||||
from controllers.console.app import message as message_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
@@ -101,106 +99,3 @@ class TestChatMessageApiPermissions:
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_get_requires_edit_permission(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Ensure GET chat-messages endpoint enforces edit permissions."""
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
conversation_id = uuid.uuid4()
|
||||
created_at = naive_utc_now()
|
||||
|
||||
mock_conversation = SimpleNamespace(id=str(conversation_id), app_id=str(mock_app_model.id))
|
||||
mock_message = SimpleNamespace(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=str(conversation_id),
|
||||
inputs=[],
|
||||
query="hello",
|
||||
message=[{"text": "hello"}],
|
||||
message_tokens=0,
|
||||
re_sign_file_url_answer="",
|
||||
answer_tokens=0,
|
||||
provider_response_latency=0.0,
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=mock_account.id,
|
||||
feedbacks=[],
|
||||
workflow_run_id=None,
|
||||
annotation=None,
|
||||
annotation_hit_history=None,
|
||||
created_at=created_at,
|
||||
agent_thoughts=[],
|
||||
message_files=[],
|
||||
message_metadata_dict={},
|
||||
status="success",
|
||||
error="",
|
||||
parent_message_id=None,
|
||||
)
|
||||
|
||||
class MockQuery:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
if getattr(self.model, "__name__", "") == "Conversation":
|
||||
return mock_conversation
|
||||
return None
|
||||
|
||||
def order_by(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, *_):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
if getattr(self.model, "__name__", "") == "Message":
|
||||
return [mock_message]
|
||||
return []
|
||||
|
||||
mock_session = mock.Mock()
|
||||
mock_session.query.side_effect = MockQuery
|
||||
mock_session.scalar.return_value = False
|
||||
|
||||
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
class DummyPagination:
|
||||
def __init__(self, data, limit, has_more):
|
||||
self.data = data
|
||||
self.limit = limit
|
||||
self.has_more = has_more
|
||||
|
||||
monkeypatch.setattr(message_api, "InfiniteScrollPagination", DummyPagination)
|
||||
|
||||
mock_account.role = role
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/chat-messages",
|
||||
headers=auth_header,
|
||||
query_string={"conversation_id": str(conversation_id)},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@@ -27,11 +27,13 @@ class MockedHttp:
|
||||
@classmethod
|
||||
def requests_request(
|
||||
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> httpx.Response:
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Mocked httpx.request
|
||||
Mocked requests.request
|
||||
"""
|
||||
request = httpx.Request(method, url)
|
||||
request = requests.PreparedRequest()
|
||||
request.method = method
|
||||
request.url = url
|
||||
if url.endswith("/tools"):
|
||||
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
|
||||
code=0, message="success", data=cls.list_tools()
|
||||
@@ -39,7 +41,8 @@ class MockedHttp:
|
||||
else:
|
||||
raise ValueError("")
|
||||
|
||||
response = httpx.Response(status_code=200)
|
||||
response = requests.Response()
|
||||
response.status_code = 200
|
||||
response.request = request
|
||||
response._content = content.encode("utf-8")
|
||||
return response
|
||||
@@ -51,7 +54,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
@pytest.fixture
|
||||
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK_SWITCH:
|
||||
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
|
||||
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
|
||||
|
||||
def unpatch():
|
||||
monkeypatch.undo()
|
||||
|
||||
@@ -100,8 +100,8 @@ class MockBaiduVectorDBClass:
|
||||
"row": {
|
||||
"id": primary_key.get("id"),
|
||||
"vector": [0.23432432, 0.8923744, 0.89238432],
|
||||
"page_content": "text",
|
||||
"metadata": {"doc_id": "doc_id_001"},
|
||||
"text": "text",
|
||||
"metadata": '{"doc_id": "doc_id_001"}',
|
||||
},
|
||||
"code": 0,
|
||||
"msg": "Success",
|
||||
@@ -127,8 +127,8 @@ class MockBaiduVectorDBClass:
|
||||
"row": {
|
||||
"id": "doc_id_001",
|
||||
"vector": [0.23432432, 0.8923744, 0.89238432],
|
||||
"page_content": "text",
|
||||
"metadata": {"doc_id": "doc_id_001"},
|
||||
"text": "text",
|
||||
"metadata": '{"doc_id": "doc_id_001"}',
|
||||
},
|
||||
"distance": 0.1,
|
||||
"score": 0.5,
|
||||
|
||||
@@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from clickzetta import connect
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_dify_api():
|
||||
max_retries = 30
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = httpx.get(f"{base_url}/console/api/health")
|
||||
response = requests.get(f"{base_url}/console/api/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ Dify API is ready")
|
||||
break
|
||||
|
||||
@@ -173,7 +173,7 @@ class DifyTestContainers:
|
||||
# Start Dify Plugin Daemon container for plugin management
|
||||
# Dify Plugin Daemon provides plugin lifecycle management and execution
|
||||
logger.info("Initializing Dify Plugin Daemon container...")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.2.0-local")
|
||||
self.dify_plugin_daemon.with_exposed_ports(5002)
|
||||
self.dify_plugin_daemon.env = {
|
||||
"DB_HOST": db_host,
|
||||
|
||||
@@ -13,7 +13,6 @@ import pytest
|
||||
from faker import Faker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
@@ -203,6 +202,7 @@ class TestBatchCleanDocumentTask:
|
||||
UploadFile: Created upload file instance
|
||||
"""
|
||||
fake = Faker()
|
||||
from datetime import datetime
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
@@ -216,7 +216,7 @@ class TestBatchCleanDocumentTask:
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=naive_utc_now(),
|
||||
created_at=datetime.utcnow(),
|
||||
used=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -201,9 +201,9 @@ class TestOAuthCallback:
|
||||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
# Import the real requests module to create a proper exception
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
request_exception = httpx.RequestError("OAuth error")
|
||||
request_exception = requests.exceptions.RequestException("OAuth error")
|
||||
request_exception.response = MagicMock()
|
||||
request_exception.response.text = str(exception)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Unit tests for workflow node execution conflict handling."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import psycopg2.errors
|
||||
@@ -15,7 +16,6 @@ from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestWorkflowNodeExecutionConflictHandling:
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
original_id = execution.id
|
||||
@@ -112,7 +112,7 @@ class TestWorkflowNodeExecutionConflictHandling:
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=naive_utc_now(),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Save should update existing record
|
||||
@@ -157,7 +157,7 @@ class TestWorkflowNodeExecutionConflictHandling:
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Save should raise IntegrityError after max retries
|
||||
@@ -199,7 +199,7 @@ class TestWorkflowNodeExecutionConflictHandling:
|
||||
title="Test Node",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Save should raise error immediately
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
"""Tests for graph engine event handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
|
||||
from core.workflow.graph_engine.event_management.event_manager import EventManager
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import RetryConfig
|
||||
|
||||
|
||||
class _StubEdgeProcessor:
|
||||
"""Minimal edge processor stub for tests."""
|
||||
|
||||
|
||||
class _StubErrorHandler:
|
||||
"""Minimal error handler stub for tests."""
|
||||
|
||||
|
||||
class _StubNode:
|
||||
"""Simple node stub exposing the attributes needed by the state manager."""
|
||||
|
||||
def __init__(self, node_id: str) -> None:
|
||||
self.id = node_id
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = "Stub Node"
|
||||
self.execution_type = NodeExecutionType.EXECUTABLE
|
||||
self.error_strategy = None
|
||||
self.retry_config = RetryConfig()
|
||||
self.retry = False
|
||||
|
||||
|
||||
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
|
||||
"""Construct an EventHandler with in-memory dependencies for testing."""
|
||||
|
||||
node = _StubNode(node_id)
|
||||
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
|
||||
|
||||
variable_pool = VariablePool()
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_execution = GraphExecution(workflow_id="test-workflow")
|
||||
|
||||
event_manager = EventManager()
|
||||
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
|
||||
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
|
||||
|
||||
handler = EventHandler(
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
graph_execution=graph_execution,
|
||||
response_coordinator=response_coordinator,
|
||||
event_collector=event_manager,
|
||||
edge_processor=_StubEdgeProcessor(),
|
||||
state_manager=state_manager,
|
||||
error_handler=_StubErrorHandler(),
|
||||
)
|
||||
|
||||
return handler, event_manager, graph_execution
|
||||
|
||||
|
||||
def test_retry_does_not_emit_additional_start_event() -> None:
|
||||
"""Ensure retry attempts do not produce duplicate start events."""
|
||||
|
||||
node_id = "test-node"
|
||||
handler, event_manager, graph_execution = _build_event_handler(node_id)
|
||||
|
||||
execution_id = "exec-1"
|
||||
node_type = NodeType.CODE
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(start_event)
|
||||
|
||||
retry_event = NodeRunRetryEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
error="boom",
|
||||
retry_index=1,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error="boom",
|
||||
error_type="TestError",
|
||||
),
|
||||
)
|
||||
handler.dispatch(retry_event)
|
||||
|
||||
# Simulate the node starting execution again after retry
|
||||
second_start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(second_start_event)
|
||||
|
||||
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
|
||||
|
||||
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
|
||||
|
||||
node_execution = graph_execution.get_or_create_node_execution(node_id)
|
||||
assert node_execution.retry_count == 1
|
||||
@@ -10,18 +10,11 @@ import time
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
|
||||
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
|
||||
|
||||
# Import the test framework from the new module
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
|
||||
|
||||
|
||||
@@ -728,39 +721,3 @@ def test_event_sequence_validation_with_table_tests():
|
||||
else:
|
||||
assert result.event_sequence_match is True
|
||||
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
|
||||
|
||||
|
||||
def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
||||
runner = TableTestRunner()
|
||||
|
||||
fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
|
||||
mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
|
||||
|
||||
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
query="hello",
|
||||
use_mock_factory=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
llm_node = graph.nodes["llm"]
|
||||
base_node_data = llm_node.get_base_node_data()
|
||||
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
|
||||
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
assert isinstance(events[-1], GraphRunPartialSucceededEvent)
|
||||
|
||||
partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
|
||||
assert partial_event.exceptions_count == 1
|
||||
assert partial_event.outputs.get("answer") == "fallback response"
|
||||
|
||||
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
|
||||
|
||||
65
api/tests/unit_tests/core/workflow/nodes/test_retry.py
Normal file
65
api/tests/unit_tests/core/workflow/nodes/test_retry.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
|
||||
pytest.skip(
|
||||
"Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_retry_default_value_partial_success():
|
||||
"""retry default value node with partial success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value",
|
||||
[{"key": "result", "type": "string", "value": "http node got error response"}],
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert events[-1].outputs == {"answer": "http node got error response"}
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
def test_retry_failed():
|
||||
"""retry failed with success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
None,
|
||||
None,
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||
assert len(events) == 8
|
||||
@@ -1,8 +1,8 @@
|
||||
import urllib.parse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
|
||||
@@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post")
|
||||
@patch("requests.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
@patch("requests.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
@@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("requests.get")
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
||||
|
||||
with pytest.raises(httpx.RequestError):
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
oauth.get_raw_user_info("test_token")
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post")
|
||||
@patch("requests.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
@patch("requests.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
@@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
@pytest.mark.parametrize(
|
||||
"exception_type",
|
||||
[
|
||||
httpx.HTTPError,
|
||||
httpx.ConnectError,
|
||||
httpx.TimeoutException,
|
||||
requests.exceptions.HTTPError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
@patch("requests.get")
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
import importlib
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import Message
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_file_helpers(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Patch file_helpers.get_signed_file_url to a deterministic stub.
|
||||
"""
|
||||
model_module = importlib.import_module("models.model")
|
||||
dummy = types.SimpleNamespace(get_signed_file_url=lambda fid: f"https://signed.example/{fid}")
|
||||
# Inject/override file_helpers on models.model
|
||||
monkeypatch.setattr(model_module, "file_helpers", dummy, raising=False)
|
||||
|
||||
|
||||
def _wrap_md(url: str) -> str:
|
||||
"""
|
||||
Wrap a raw URL into the markdown that re_sign_file_url_answer expects:
|
||||
[link](<url>)
|
||||
"""
|
||||
return f"please click [file]({url}) to download."
|
||||
|
||||
|
||||
def test_file_preview_valid_replaced():
|
||||
"""
|
||||
Valid file-preview URL must be re-signed:
|
||||
- Extract upload_file_id correctly
|
||||
- Replace the original URL with the signed URL
|
||||
"""
|
||||
upload_id = "abc-123"
|
||||
url = f"/files/{upload_id}/file-preview?timestamp=111&nonce=222&sign=333"
|
||||
msg = Message(answer=_wrap_md(url))
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
assert f"https://signed.example/{upload_id}" in out
|
||||
assert url not in out
|
||||
|
||||
|
||||
def test_file_preview_misspelled_not_replaced():
|
||||
"""
|
||||
Misspelled endpoint 'file-previe?timestamp=' should NOT be rewritten.
|
||||
"""
|
||||
upload_id = "zzz-001"
|
||||
# path deliberately misspelled: file-previe? (missing 'w')
|
||||
# and we append ¬e=file-preview to trick the old `"file-preview" in url` check.
|
||||
url = f"/files/{upload_id}/file-previe?timestamp=111&nonce=222&sign=333¬e=file-preview"
|
||||
original = _wrap_md(url)
|
||||
msg = Message(answer=original)
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
# Expect NO replacement, should not rewrite misspelled file-previe URL
|
||||
assert out == original
|
||||
|
||||
|
||||
def test_image_preview_valid_replaced():
|
||||
"""
|
||||
Valid image-preview URL must be re-signed.
|
||||
"""
|
||||
upload_id = "img-789"
|
||||
url = f"/files/{upload_id}/image-preview?timestamp=123&nonce=456&sign=789"
|
||||
msg = Message(answer=_wrap_md(url))
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
assert f"https://signed.example/{upload_id}" in out
|
||||
assert url not in out
|
||||
|
||||
|
||||
def test_image_preview_misspelled_not_replaced():
|
||||
"""
|
||||
Misspelled endpoint 'image-previe?timestamp=' should NOT be rewritten.
|
||||
"""
|
||||
upload_id = "img-err-42"
|
||||
url = f"/files/{upload_id}/image-previe?timestamp=1&nonce=2&sign=3¬e=image-preview"
|
||||
original = _wrap_md(url)
|
||||
msg = Message(answer=original)
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
# Expect NO replacement, should not rewrite misspelled image-previe URL
|
||||
assert out == original
|
||||
@@ -6,8 +6,8 @@ import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@@ -26,7 +26,7 @@ class TestAuthIntegration:
|
||||
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test complete authentication flow: request → validation → encryption → storage"""
|
||||
@@ -47,7 +47,7 @@ class TestAuthIntegration:
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_cross_component_integration(self, mock_http):
|
||||
"""Test factory → provider → HTTP call integration"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
@@ -97,7 +97,7 @@ class TestAuthIntegration:
|
||||
assert "another_secret" not in factory_str
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test concurrent authentication creation safety"""
|
||||
@@ -142,31 +142,31 @@ class TestAuthIntegration:
|
||||
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
|
||||
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_http_error_handling(self, mock_http):
|
||||
"""Test proper HTTP error handling"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = '{"error": "Unauthorized"}'
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
|
||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
|
||||
mock_http.return_value = mock_response
|
||||
|
||||
# PT012: Split into single statement for pytest.raises
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
with pytest.raises((httpx.HTTPError, Exception)):
|
||||
with pytest.raises((requests.exceptions.HTTPError, Exception)):
|
||||
factory.validate_credentials()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_network_failure_recovery(self, mock_http, mock_session):
|
||||
"""Test system recovery from network failures"""
|
||||
mock_http.side_effect = httpx.RequestError("Network timeout")
|
||||
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
|
||||
with pytest.raises(httpx.RequestError):
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
@@ -64,7 +64,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@@ -95,7 +95,7 @@ class TestFirecrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@@ -115,7 +115,7 @@ class TestFirecrawlAuth:
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@@ -134,13 +134,13 @@ class TestFirecrawlAuth:
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
@@ -162,7 +162,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
@@ -179,12 +179,12 @@ class TestFirecrawlAuth:
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
with pytest.raises(requests.Timeout) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.jina.jina import JinaAuth
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestJinaAuth:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@@ -53,7 +53,7 @@ class TestJinaAuth:
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
@@ -68,7 +68,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
@@ -83,7 +83,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
@@ -98,7 +98,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
@@ -114,7 +114,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
@@ -130,15 +130,15 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
mock_post.side_effect = requests.ConnectionError("Network error")
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
with pytest.raises(requests.ConnectionError):
|
||||
auth.validate_credentials()
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.watercrawl.watercrawl import WatercrawlAuth
|
||||
|
||||
@@ -64,7 +64,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@@ -87,7 +87,7 @@ class TestWatercrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@@ -107,7 +107,7 @@ class TestWatercrawlAuth:
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@@ -126,13 +126,13 @@ class TestWatercrawlAuth:
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
@@ -154,7 +154,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
@@ -179,7 +179,7 @@ class TestWatercrawlAuth:
|
||||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
@@ -193,12 +193,12 @@ class TestWatercrawlAuth:
|
||||
# Verify the correct URL was called
|
||||
assert mock_get.call_args[0][0] == expected_url
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
with pytest.raises(requests.Timeout) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user