mirror of
https://github.com/langgenius/dify.git
synced 2026-04-07 18:45:11 +08:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
acfd34e876 | ||
|
|
036a7cf839 | ||
|
|
86beacc64f | ||
|
|
2c6bd90d6f | ||
|
|
f5aaa8f97e | ||
|
|
a22cc5bc5e | ||
|
|
1fbdf6b465 | ||
|
|
491e1fd6a4 | ||
|
|
0e33dfb5c2 | ||
|
|
ea708e7a32 | ||
|
|
c09e29c3f8 | ||
|
|
2d53ba8671 | ||
|
|
9be863fefa | ||
|
|
8f43629cd8 | ||
|
|
9ee71902c1 | ||
|
|
a012c87445 | ||
|
|
450578d4c0 | ||
|
|
837237aa6d | ||
|
|
b63dfbf654 | ||
|
|
51ea87ab85 | ||
|
|
00698e41b7 | ||
|
|
df938a4543 | ||
|
|
9161936f41 | ||
|
|
f9a21b56ab | ||
|
|
220e1df847 | ||
|
|
8cfdde594c | ||
|
|
31a8fd810c | ||
|
|
9fad97ec9b | ||
|
|
0c2729d9b3 | ||
|
|
a2e03b811e | ||
|
|
1e10bf525c | ||
|
|
8b1af36d94 | ||
|
|
0711dd4159 | ||
|
|
ae0a26f5b6 |
@@ -5,5 +5,18 @@
|
||||
"typescript-lsp@claude-plugins-official": true,
|
||||
"pyright-lsp@claude-plugins-official": true,
|
||||
"ralph-loop@claude-plugins-official": true
|
||||
},
|
||||
"hooks": {
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "Bash",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "npx -y block-no-verify@1.1.1"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@@ -39,12 +39,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev pyrefly
|
||||
uv run pyrefly check || true
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
||||
|
||||
29
.github/workflows/deploy-hitl.yml
vendored
Normal file
29
.github/workflows/deploy-hitl.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Deploy HITL
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "feat/hitl-frontend"
|
||||
- "feat/hitl-backend"
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
(
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
|
||||
)
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.HITL_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
uses: actions/setup-node@v6
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
8
.github/workflows/tool-test-sdks.yaml
vendored
8
.github/workflows/tool-test-sdks.yaml
vendored
@@ -16,10 +16,6 @@ jobs:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [16, 18, 20, 22]
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: sdks/nodejs-client
|
||||
@@ -29,10 +25,10 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
node-version: 24
|
||||
cache: ''
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
|
||||
@@ -589,6 +589,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
|
||||
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
|
||||
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
|
||||
ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
@@ -34,7 +35,7 @@ from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
@@ -45,6 +46,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -62,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
@@ -84,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@@ -100,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@@ -658,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip()
|
||||
email = email.strip().lower()
|
||||
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Invalid email address.", fg="red"))
|
||||
@@ -852,6 +858,61 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
|
||||
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
|
||||
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
|
||||
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run",
|
||||
is_flag=True,
|
||||
help="Preview cleanup results without deleting any workflow run data.",
|
||||
)
|
||||
def clean_workflow_runs(
|
||||
days: int,
|
||||
batch_size: int,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow run cleanup completed. start={start_time.isoformat()} "
|
||||
f"end={end_time.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
|
||||
def clear_orphaned_file_records(force: bool):
|
||||
|
||||
@@ -1101,6 +1101,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable clean messages task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
@@ -100,11 +99,12 @@ class ActivateApi(Resource):
|
||||
def post(self):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
normalized_request_email = args.email.lower() if args.email else None
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args.name
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
|
||||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
AccountService.add_email_register_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
|
||||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_email_register_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email, password) -> Account | None:
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
|
||||
@@ -4,7 +4,6 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
|
||||
def post(self):
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
token_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
||||
@@ -90,32 +90,38 @@ class LoginApi(Resource):
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
request_email = args.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
invitation_data: dict[str, Any] | None = None
|
||||
if args.invite_token:
|
||||
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
if invitation_data is None:
|
||||
invite_token = None
|
||||
|
||||
try:
|
||||
if invitation_data:
|
||||
data = invitation_data.get("data", {})
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args.email:
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
@@ -130,7 +136,7 @@ class LoginApi(Resource):
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
@@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
@@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
@@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
original_email = args.email
|
||||
user_email = original_email.lower()
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args.email:
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
@@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
@@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
@@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
account = AccountService.get_user_through_email(email)
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return AccountService.get_user_through_email(email.lower())
|
||||
|
||||
|
||||
def _authenticate_account_with_case_fallback(
|
||||
original_email: str, normalized_email: str, password: str, invite_token: str | None
|
||||
):
|
||||
try:
|
||||
return AccountService.authenticate(original_email, password, invite_token)
|
||||
except services.errors.account.AccountPasswordError:
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@@ -118,7 +117,10 @@ class OAuthCallback(Resource):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
invitation_email_normalized = (
|
||||
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||
)
|
||||
if invitation_email_normalized != user_info.email.lower():
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
@@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
@@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
normalized_email = user_info.email.lower()
|
||||
oauth_new_user = True
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
"This email account has been deleted within the past "
|
||||
@@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
|
||||
@@ -84,10 +84,11 @@ class SetupApi(Resource):
|
||||
raise NotInitValidateError()
|
||||
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
|
||||
@@ -41,7 +41,7 @@ from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
account = None
|
||||
user_email = args.email
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
@@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email != current_user.email:
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
raise InvalidEmailError()
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
user_email = account.email
|
||||
|
||||
token = AccountService.send_change_email_email(
|
||||
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
|
||||
account=account,
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
|
||||
if is_change_email_error_rate_limit:
|
||||
raise EmailChangeLimitError()
|
||||
|
||||
@@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_change_email_error_rate_limit(args.email)
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
@@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
|
||||
if AccountService.is_account_in_freeze(args.new_email):
|
||||
if AccountService.is_account_in_freeze(normalized_new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(args.new_email):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
@@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email != old_email:
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=args.new_email,
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
@@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = CheckEmailUniquePayload.model_validate(payload)
|
||||
if AccountService.is_account_in_freeze(args.email):
|
||||
normalized_email = args.email.lower()
|
||||
if AccountService.is_account_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args.email):
|
||||
if not AccountService.check_email_unique(normalized_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
normalized_invitee_email = invitee_email.lower()
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
encoded_invitee_email = parse.quote(normalized_invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": invitee_email,
|
||||
"email": normalized_invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
||||
@@ -4,7 +4,6 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -22,7 +21,7 @@ from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
def post(self):
|
||||
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
request_email = payload.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
@@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
|
||||
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
|
||||
def post(self):
|
||||
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = payload.email
|
||||
user_email = payload.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
@@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if payload.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(payload.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
@@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
token_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(payload.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
@@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
||||
@@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args["email"].lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
|
||||
thought = MessageAgentThought(
|
||||
message_id=message_id,
|
||||
message_chain_id=None,
|
||||
tool_process_data=None,
|
||||
thought="",
|
||||
tool=tool_name,
|
||||
tool_labels_str="{}",
|
||||
@@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
message_unit_price=Decimal(0),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
message_files=json.dumps(messages_ids) if messages_ids else "",
|
||||
answer="",
|
||||
observation="",
|
||||
answer_token=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
answer_unit_price=Decimal(0),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
tokens=0,
|
||||
total_price=0,
|
||||
total_price=Decimal(0),
|
||||
position=self.agent_thought_count + 1,
|
||||
currency="USD",
|
||||
latency=0,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
)
|
||||
|
||||
@@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
if thought:
|
||||
agent_thought.thought += thought
|
||||
existing_thought = agent_thought.thought or ""
|
||||
agent_thought.thought = f"{existing_thought}{thought}"
|
||||
|
||||
if tool_name:
|
||||
agent_thought.tool = tool_name
|
||||
@@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tool
|
||||
if tools:
|
||||
tools = tools.split(";")
|
||||
tool_names_raw = agent_thought.tool
|
||||
if tool_names_raw:
|
||||
tool_names = tool_names_raw.split(";")
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
tool_input_payload = agent_thought.tool_input
|
||||
if tool_input_payload:
|
||||
try:
|
||||
tool_inputs = json.loads(tool_input_payload)
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tool_names}
|
||||
else:
|
||||
tool_inputs = {tool: {} for tool in tool_names}
|
||||
|
||||
for tool in tools:
|
||||
observation_payload = agent_thought.observation
|
||||
if observation_payload:
|
||||
try:
|
||||
tool_responses = json.loads(observation_payload)
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tool_names, observation_payload)
|
||||
else:
|
||||
tool_responses = dict.fromkeys(tool_names, observation_payload)
|
||||
|
||||
for tool in tool_names:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
|
||||
*tool_call_response,
|
||||
]
|
||||
)
|
||||
if not tools:
|
||||
if not tool_names_raw:
|
||||
result.append(AssistantPromptMessage(content=agent_thought.thought))
|
||||
else:
|
||||
if message.answer:
|
||||
|
||||
@@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
@@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
else:
|
||||
assistant_message.content = response
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: str | None = Field(default=None)
|
||||
json_schema: dict | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: str | None) -> str | None:
|
||||
def validate_json_schema(cls, schema: dict | None) -> dict | None:
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"invalid json_schema value {schema}")
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(json_schema)
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
features_dict = workflow.features_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AdvancedChatAppConfig(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
||||
@@ -358,25 +358,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
||||
# For ANSWER nodes, check if we need to send a message_replace event
|
||||
# Only send if the final output differs from the accumulated task_state.answer
|
||||
# This happens when variables were updated by variable_assigner during workflow execution
|
||||
if event.node_type == NodeType.ANSWER and event.outputs:
|
||||
final_answer = event.outputs.get("answer")
|
||||
if final_answer is not None and final_answer != self._task_state.answer:
|
||||
logger.info(
|
||||
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
|
||||
final_answer,
|
||||
self._task_state.answer,
|
||||
)
|
||||
# Update the task state answer
|
||||
self._task_state.answer = str(final_answer)
|
||||
# Send message_replace event to update the UI
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=str(final_answer),
|
||||
reason="variable_update",
|
||||
)
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
@@ -76,12 +75,24 @@ class BaseAppGenerator:
|
||||
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
|
||||
|
||||
# Check if all files are converted to File
|
||||
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
|
||||
raise ValueError("Invalid input type")
|
||||
if any(
|
||||
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
|
||||
):
|
||||
raise ValueError("Invalid input type")
|
||||
invalid_dict_keys = [
|
||||
k
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, dict)
|
||||
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
|
||||
]
|
||||
if invalid_dict_keys:
|
||||
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
|
||||
|
||||
invalid_list_dict_keys = [
|
||||
k
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, list)
|
||||
and any(isinstance(item, dict) for item in v)
|
||||
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
|
||||
]
|
||||
if invalid_list_dict_keys:
|
||||
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
|
||||
|
||||
return user_inputs
|
||||
|
||||
@@ -178,12 +189,8 @@ class BaseAppGenerator:
|
||||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a string")
|
||||
try:
|
||||
json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
|
||||
if value and not isinstance(value, dict):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
||||
@@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_calls:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_calls
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class SpanBuilder:
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
kind=span_data.span_kind,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
|
||||
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
|
||||
|
||||
@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
|
||||
def is_ended(self) -> bool:
|
||||
return self in _END_STATE
|
||||
|
||||
@classmethod
|
||||
def ended_values(cls) -> list[str]:
|
||||
return [status.value for status in _END_STATE]
|
||||
|
||||
|
||||
_END_STATE = frozenset(
|
||||
[
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
@@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
|
||||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
# If no value provided, skip further processing for this key
|
||||
if not value:
|
||||
continue
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"JSON object for '{key}' must be an object")
|
||||
|
||||
# Overwrite with normalized dict to ensure downstream consistency
|
||||
node_inputs[key] = value
|
||||
|
||||
# If schema exists, then validate against it
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{schema} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{value} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
Draft7Validator(json_schema).validate(json_value)
|
||||
Draft7Validator(schema).validate(value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = json_value
|
||||
|
||||
@@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
|
||||
Returns True if this node updates any of the requested conversation variables.
|
||||
"""
|
||||
assigned_selector = tuple(self.node_data.assigned_variable_selector)
|
||||
return assigned_selector in variable_selectors
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@@ -136,13 +137,11 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -158,12 +157,12 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
|
||||
@@ -163,6 +163,13 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
|
||||
"schedule": crontab(minute="0", hour="2"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
|
||||
# for saas only
|
||||
imports.append("schedule.clean_workflow_runs_task")
|
||||
beat_schedule["clean_workflow_runs_task"] = {
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
@@ -4,6 +4,7 @@ from dify_app import DifyApp
|
||||
def init_app(app: DifyApp):
|
||||
from commands import (
|
||||
add_qdrant_index,
|
||||
clean_workflow_runs,
|
||||
cleanup_orphaned_draft_variables,
|
||||
clear_free_plan_tenant_expired_logs,
|
||||
clear_orphaned_file_records,
|
||||
@@ -56,6 +57,7 @@ def init_app(app: DifyApp):
|
||||
setup_datasource_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
install_rag_pipeline_plugins,
|
||||
clean_workflow_runs,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TypeAlias
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
@@ -20,8 +21,8 @@ class SimpleFeedback(ResponseModel):
|
||||
|
||||
|
||||
class RetrieverResource(ResponseModel):
|
||||
id: str
|
||||
message_id: str
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
message_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
position: int
|
||||
dataset_id: str | None = None
|
||||
dataset_name: str | None = None
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add workflow_run_created_at_id_idx
|
||||
|
||||
Revision ID: 905527cc8fd3
|
||||
Revises: 7df29de0f6be
|
||||
Create Date: 2025-01-09 16:30:02.462084
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '905527cc8fd3'
|
||||
down_revision = '7df29de0f6be'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_run_created_at_id_idx')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1843,7 +1843,7 @@ class MessageChain(TypeBase):
|
||||
)
|
||||
|
||||
|
||||
class MessageAgentThought(Base):
|
||||
class MessageAgentThought(TypeBase):
|
||||
__tablename__ = "message_agent_thoughts"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
@@ -1851,34 +1851,42 @@ class MessageAgentThought(Base):
|
||||
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
message_chain_id = mapped_column(StringUUID, nullable=True)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
thought = mapped_column(LongText, nullable=True)
|
||||
tool = mapped_column(LongText, nullable=True)
|
||||
tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
|
||||
tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
|
||||
tool_input = mapped_column(LongText, nullable=True)
|
||||
observation = mapped_column(LongText, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
|
||||
tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
|
||||
tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
|
||||
tool_process_data = mapped_column(LongText, nullable=True)
|
||||
message = mapped_column(LongText, nullable=True)
|
||||
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
message_unit_price = mapped_column(sa.Numeric, nullable=True)
|
||||
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
message_files = mapped_column(LongText, nullable=True)
|
||||
answer = mapped_column(LongText, nullable=True)
|
||||
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
total_price = mapped_column(sa.Numeric, nullable=True)
|
||||
currency = mapped_column(String(255), nullable=True)
|
||||
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
|
||||
created_by_role = mapped_column(String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
|
||||
message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
|
||||
message_price_unit: Mapped[Decimal] = mapped_column(
|
||||
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
|
||||
)
|
||||
message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
|
||||
answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
|
||||
answer_price_unit: Mapped[Decimal] = mapped_column(
|
||||
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
|
||||
)
|
||||
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
|
||||
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
|
||||
currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def files(self) -> list[Any]:
|
||||
|
||||
@@ -597,6 +597,7 @@ class WorkflowRun(Base):
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
|
||||
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
|
||||
sa.Index("workflow_run_created_at_id_idx", "created_at", "id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.11.2"
|
||||
version = "1.11.4"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
@@ -189,7 +189,7 @@ storage = [
|
||||
"opendal~=0.46.0",
|
||||
"oss2==2.18.5",
|
||||
"supabase~=2.18.1",
|
||||
"tos~=2.7.1",
|
||||
"tos~=2.9.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
||||
@@ -34,11 +34,14 @@ Example:
|
||||
```
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
@@ -253,6 +256,44 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_runs_batch_by_time_range(
|
||||
self,
|
||||
start_from: datetime | None,
|
||||
end_before: datetime,
|
||||
last_seen: tuple[datetime, str] | None,
|
||||
batch_size: int,
|
||||
run_types: Sequence[WorkflowType] | None = None,
|
||||
tenant_ids: Sequence[str] | None = None,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Fetch ended workflow runs in a time window for archival and clean batching.
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
|
||||
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Delete workflow runs and their related records (node executions, offloads, app logs,
|
||||
trigger logs, pauses, pause reasons).
|
||||
"""
|
||||
...
|
||||
|
||||
def count_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
|
||||
count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Count workflow runs and their related records (node executions, offloads, app logs,
|
||||
trigger logs, pauses, pause reasons) without deleting data.
|
||||
"""
|
||||
...
|
||||
|
||||
def create_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
||||
@@ -7,13 +7,18 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import TypedDict, cast
|
||||
|
||||
from sqlalchemy import asc, delete, desc, select
|
||||
from sqlalchemy import asc, delete, desc, func, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionOffload,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
@@ -44,6 +49,26 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
@staticmethod
|
||||
def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str:
|
||||
"""
|
||||
Map workflow run triggered_from values to workflow node execution triggered_from values.
|
||||
"""
|
||||
if triggered_from in {
|
||||
WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
WorkflowRunTriggeredFrom.SCHEDULE.value,
|
||||
WorkflowRunTriggeredFrom.PLUGIN.value,
|
||||
WorkflowRunTriggeredFrom.WEBHOOK.value,
|
||||
}:
|
||||
return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
if triggered_from in {
|
||||
WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
|
||||
WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
|
||||
}:
|
||||
return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value
|
||||
return ""
|
||||
|
||||
def get_node_last_execution(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -290,3 +315,119 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
result = cast(CursorResult, session.execute(stmt))
|
||||
session.commit()
|
||||
return result.rowcount
|
||||
|
||||
class RunContext(TypedDict):
|
||||
run_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
triggered_from: str
|
||||
|
||||
@staticmethod
|
||||
def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
|
||||
"""
|
||||
Delete node executions (and offloads) for the given workflow runs using indexed columns.
|
||||
|
||||
Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id)
|
||||
by filtering on those columns with tuple IN.
|
||||
"""
|
||||
if not runs:
|
||||
return 0, 0
|
||||
|
||||
tuple_values = [
|
||||
(
|
||||
run["tenant_id"],
|
||||
run["app_id"],
|
||||
run["workflow_id"],
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
|
||||
run["triggered_from"]
|
||||
),
|
||||
run["run_id"],
|
||||
)
|
||||
for run in runs
|
||||
]
|
||||
|
||||
node_execution_ids = session.scalars(
|
||||
select(WorkflowNodeExecutionModel.id).where(
|
||||
tuple_(
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id,
|
||||
WorkflowNodeExecutionModel.triggered_from,
|
||||
WorkflowNodeExecutionModel.workflow_run_id,
|
||||
).in_(tuple_values)
|
||||
)
|
||||
).all()
|
||||
|
||||
if not node_execution_ids:
|
||||
return 0, 0
|
||||
|
||||
offloads_deleted = (
|
||||
cast(
|
||||
CursorResult,
|
||||
session.execute(
|
||||
delete(WorkflowNodeExecutionOffload).where(
|
||||
WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
or 0
|
||||
)
|
||||
|
||||
node_executions_deleted = (
|
||||
cast(
|
||||
CursorResult,
|
||||
session.execute(
|
||||
delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
|
||||
),
|
||||
).rowcount
|
||||
or 0
|
||||
)
|
||||
|
||||
return node_executions_deleted, offloads_deleted
|
||||
|
||||
@staticmethod
|
||||
def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
|
||||
"""
|
||||
Count node executions (and offloads) for the given workflow runs using indexed columns.
|
||||
"""
|
||||
if not runs:
|
||||
return 0, 0
|
||||
|
||||
tuple_values = [
|
||||
(
|
||||
run["tenant_id"],
|
||||
run["app_id"],
|
||||
run["workflow_id"],
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
|
||||
run["triggered_from"]
|
||||
),
|
||||
run["run_id"],
|
||||
)
|
||||
for run in runs
|
||||
]
|
||||
tuple_filter = tuple_(
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id,
|
||||
WorkflowNodeExecutionModel.triggered_from,
|
||||
WorkflowNodeExecutionModel.workflow_run_id,
|
||||
).in_(tuple_values)
|
||||
|
||||
node_executions_count = (
|
||||
session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0
|
||||
)
|
||||
offloads_count = (
|
||||
session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowNodeExecutionOffload)
|
||||
.join(
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id,
|
||||
)
|
||||
.where(tuple_filter)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
return int(node_executions_count), int(offloads_count)
|
||||
|
||||
@@ -21,7 +21,7 @@ Implementation Notes:
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
@@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import convert_datetime_to_date
|
||||
@@ -40,8 +40,14 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowPauseReason, WorkflowRun
|
||||
from models.workflow import (
|
||||
WorkflowAppLog,
|
||||
WorkflowPauseReason,
|
||||
WorkflowRun,
|
||||
)
|
||||
from models.workflow import (
|
||||
WorkflowPause as WorkflowPauseModel,
|
||||
)
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.types import (
|
||||
@@ -314,6 +320,171 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
||||
return total_deleted
|
||||
|
||||
def get_runs_batch_by_time_range(
|
||||
self,
|
||||
start_from: datetime | None,
|
||||
end_before: datetime,
|
||||
last_seen: tuple[datetime, str] | None,
|
||||
batch_size: int,
|
||||
run_types: Sequence[WorkflowType] | None = None,
|
||||
tenant_ids: Sequence[str] | None = None,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Fetch ended workflow runs in a time window for archival and clean batching.
|
||||
|
||||
Query scope:
|
||||
- created_at in [start_from, end_before)
|
||||
- type in run_types (when provided)
|
||||
- status is an ended state
|
||||
- optional tenant_id filter and cursor (last_seen) for pagination
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
stmt = (
|
||||
select(WorkflowRun)
|
||||
.where(
|
||||
WorkflowRun.created_at < end_before,
|
||||
WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
|
||||
)
|
||||
.order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
|
||||
.limit(batch_size)
|
||||
)
|
||||
if run_types is not None:
|
||||
if not run_types:
|
||||
return []
|
||||
stmt = stmt.where(WorkflowRun.type.in_(run_types))
|
||||
|
||||
if start_from:
|
||||
stmt = stmt.where(WorkflowRun.created_at >= start_from)
|
||||
|
||||
if tenant_ids:
|
||||
stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids))
|
||||
|
||||
if last_seen:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
WorkflowRun.created_at > last_seen[0],
|
||||
and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
|
||||
)
|
||||
)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
|
||||
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if not runs:
|
||||
return {
|
||||
"runs": 0,
|
||||
"node_executions": 0,
|
||||
"offloads": 0,
|
||||
"app_logs": 0,
|
||||
"trigger_logs": 0,
|
||||
"pauses": 0,
|
||||
"pause_reasons": 0,
|
||||
}
|
||||
|
||||
with self._session_maker() as session:
|
||||
run_ids = [run.id for run in runs]
|
||||
if delete_node_executions:
|
||||
node_executions_deleted, offloads_deleted = delete_node_executions(session, runs)
|
||||
else:
|
||||
node_executions_deleted, offloads_deleted = 0, 0
|
||||
|
||||
app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
|
||||
app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
|
||||
|
||||
pause_ids = session.scalars(
|
||||
select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
|
||||
).all()
|
||||
pause_reasons_deleted = 0
|
||||
pauses_deleted = 0
|
||||
|
||||
if pause_ids:
|
||||
pause_reasons_result = session.execute(
|
||||
delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
|
||||
)
|
||||
pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
|
||||
pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids)))
|
||||
pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
|
||||
|
||||
trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
|
||||
|
||||
runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
|
||||
runs_deleted = cast(CursorResult, runs_result).rowcount or 0
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"runs": runs_deleted,
|
||||
"node_executions": node_executions_deleted,
|
||||
"offloads": offloads_deleted,
|
||||
"app_logs": app_logs_deleted,
|
||||
"trigger_logs": trigger_logs_deleted,
|
||||
"pauses": pauses_deleted,
|
||||
"pause_reasons": pause_reasons_deleted,
|
||||
}
|
||||
|
||||
def count_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
|
||||
count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if not runs:
|
||||
return {
|
||||
"runs": 0,
|
||||
"node_executions": 0,
|
||||
"offloads": 0,
|
||||
"app_logs": 0,
|
||||
"trigger_logs": 0,
|
||||
"pauses": 0,
|
||||
"pause_reasons": 0,
|
||||
}
|
||||
|
||||
with self._session_maker() as session:
|
||||
run_ids = [run.id for run in runs]
|
||||
if count_node_executions:
|
||||
node_executions_count, offloads_count = count_node_executions(session, runs)
|
||||
else:
|
||||
node_executions_count, offloads_count = 0, 0
|
||||
|
||||
app_logs_count = (
|
||||
session.scalar(
|
||||
select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
pause_ids = session.scalars(
|
||||
select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
|
||||
).all()
|
||||
pauses_count = len(pause_ids)
|
||||
pause_reasons_count = 0
|
||||
if pause_ids:
|
||||
pause_reasons_count = (
|
||||
session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowPauseReason)
|
||||
.where(WorkflowPauseReason.pause_id.in_(pause_ids))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0
|
||||
|
||||
return {
|
||||
"runs": len(runs),
|
||||
"node_executions": node_executions_count,
|
||||
"offloads": offloads_count,
|
||||
"app_logs": int(app_logs_count),
|
||||
"trigger_logs": trigger_logs_count,
|
||||
"pauses": pauses_count,
|
||||
"pause_reasons": int(pause_reasons_count),
|
||||
}
|
||||
|
||||
def create_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
||||
@@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy import and_, delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import WorkflowTriggerStatus
|
||||
@@ -84,3 +86,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
)
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs associated with the given workflow run ids.
|
||||
|
||||
Args:
|
||||
run_ids: Collection of workflow run identifiers.
|
||||
|
||||
Returns:
|
||||
Number of rows deleted.
|
||||
"""
|
||||
if not run_ids:
|
||||
return 0
|
||||
|
||||
result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)))
|
||||
return cast(CursorResult, result).rowcount or 0
|
||||
|
||||
def count_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Count trigger logs associated with the given workflow run ids.
|
||||
|
||||
Args:
|
||||
run_ids: Collection of workflow run identifiers.
|
||||
|
||||
Returns:
|
||||
Number of rows matched.
|
||||
"""
|
||||
if not run_ids:
|
||||
return 0
|
||||
|
||||
count = self.session.scalar(
|
||||
select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
|
||||
)
|
||||
return int(count or 0)
|
||||
|
||||
@@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol):
|
||||
A sequence of recent WorkflowTriggerLog instances
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs for workflow run IDs.
|
||||
|
||||
Args:
|
||||
run_ids: Workflow run IDs to delete
|
||||
|
||||
Returns:
|
||||
Number of rows deleted
|
||||
"""
|
||||
...
|
||||
|
||||
43
api/schedule/clean_workflow_runs_task.py
Normal file
43
api/schedule/clean_workflow_runs_task.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import click
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
|
||||
|
||||
@app.celery.task(queue="retention")
|
||||
def clean_workflow_runs_task() -> None:
|
||||
"""
|
||||
Scheduled cleanup for workflow runs and related records (sandbox tenants only).
|
||||
"""
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
"Scheduled workflow run cleanup starting: "
|
||||
f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, "
|
||||
f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}"
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
|
||||
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
|
||||
start_from=None,
|
||||
end_before=None,
|
||||
).run()
|
||||
|
||||
end_time = datetime.now(UTC)
|
||||
elapsed = end_time - start_time
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
|
||||
f"end={end_time.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@@ -8,7 +8,7 @@ from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@@ -748,6 +748,21 @@ class AccountService:
|
||||
cls.email_code_login_rate_limiter.increment_rate_limit(email)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
query_session = session or db.session
|
||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "email_code_login")
|
||||
@@ -1363,16 +1378,22 @@ class RegisterService:
|
||||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
|
||||
normalized_email = email.lower()
|
||||
|
||||
"""Invite new member"""
|
||||
with Session(db.engine) as session:
|
||||
account = session.query(Account).filter_by(email=email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
name = email.split("@")[0]
|
||||
name = normalized_email.split("@")[0]
|
||||
|
||||
account = cls.register(
|
||||
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
|
||||
email=normalized_email,
|
||||
name=name,
|
||||
language=language,
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
# Create new tenant member for invited tenant
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
@@ -1394,7 +1415,7 @@ class RegisterService:
|
||||
# send email
|
||||
send_invite_member_mail_task.delay(
|
||||
language=language,
|
||||
to=email,
|
||||
to=account.email,
|
||||
token=token,
|
||||
inviter_name=inviter.name if inviter else "Dify",
|
||||
workspace_name=tenant.name,
|
||||
@@ -1493,6 +1514,16 @@ class RegisterService:
|
||||
invitation: dict = json.loads(data)
|
||||
return invitation
|
||||
|
||||
@classmethod
|
||||
def get_invitation_with_case_fallback(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
) -> dict[str, Any] | None:
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
|
||||
if invitation or not email or email == email.lower():
|
||||
return invitation
|
||||
normalized_email = email.lower()
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
|
||||
|
||||
|
||||
def _generate_refresh_token(length: int = 64):
|
||||
token = secrets.token_hex(length)
|
||||
|
||||
0
api/services/retention/__init__.py
Normal file
0
api/services/retention/__init__.py
Normal file
0
api/services/retention/workflow_run/__init__.py
Normal file
0
api/services/retention/workflow_run/__init__.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import datetime
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunCleanup:
|
||||
def __init__(
|
||||
self,
|
||||
days: int,
|
||||
batch_size: int,
|
||||
start_from: datetime.datetime | None = None,
|
||||
end_before: datetime.datetime | None = None,
|
||||
workflow_run_repo: APIWorkflowRunRepository | None = None,
|
||||
dry_run: bool = False,
|
||||
):
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise ValueError("start_from and end_before must be both set or both omitted.")
|
||||
|
||||
computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||
self.window_start = start_from
|
||||
self.window_end = end_before or computed_cutoff
|
||||
|
||||
if self.window_start and self.window_end <= self.window_start:
|
||||
raise ValueError("end_before must be greater than start_from.")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size must be greater than 0.")
|
||||
|
||||
self.batch_size = batch_size
|
||||
self._cleanup_whitelist: set[str] | None = None
|
||||
self.dry_run = dry_run
|
||||
self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
|
||||
self.workflow_run_repo: APIWorkflowRunRepository
|
||||
if workflow_run_repo:
|
||||
self.workflow_run_repo = workflow_run_repo
|
||||
else:
|
||||
# Lazy import to avoid circular dependencies during module import
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
def run(self) -> None:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs "
|
||||
f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}"
|
||||
f"{self.window_end.isoformat()} (batch={self.batch_size})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
if self.dry_run:
|
||||
click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow"))
|
||||
|
||||
total_runs_deleted = 0
|
||||
total_runs_targeted = 0
|
||||
related_totals = self._empty_related_counts() if self.dry_run else None
|
||||
batch_index = 0
|
||||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
|
||||
while True:
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if not run_rows:
|
||||
break
|
||||
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
|
||||
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
||||
tenant_id_list = list(tenant_ids)
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return set(tenant_id_list)
|
||||
|
||||
if not tenant_id_list:
|
||||
return set()
|
||||
|
||||
cleanup_whitelist = self._get_cleanup_whitelist()
|
||||
|
||||
try:
|
||||
bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list)
|
||||
except Exception:
|
||||
bulk_info = {}
|
||||
logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list)
|
||||
|
||||
eligible_free_tenants: set[str] = set()
|
||||
for tenant_id in tenant_id_list:
|
||||
if tenant_id in cleanup_whitelist:
|
||||
continue
|
||||
|
||||
info = bulk_info.get(tenant_id)
|
||||
if info is None:
|
||||
logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
|
||||
continue
|
||||
|
||||
if info.get("plan") != CloudPlan.SANDBOX:
|
||||
continue
|
||||
|
||||
if self._is_within_grace_period(tenant_id, info):
|
||||
continue
|
||||
|
||||
eligible_free_tenants.add(tenant_id)
|
||||
|
||||
return eligible_free_tenants
|
||||
|
||||
def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None:
|
||||
if expiration_value < 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC)
|
||||
except (OverflowError, OSError, ValueError):
|
||||
logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id)
|
||||
return None
|
||||
|
||||
def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool:
|
||||
if self.free_plan_grace_period_days <= 0:
|
||||
return False
|
||||
|
||||
expiration_value = info.get("expiration_date", -1)
|
||||
expiration_at = self._expiration_datetime(tenant_id, expiration_value)
|
||||
if expiration_at is None:
|
||||
return False
|
||||
|
||||
grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days)
|
||||
return datetime.datetime.now(datetime.UTC) < grace_deadline
|
||||
|
||||
def _get_cleanup_whitelist(self) -> set[str]:
|
||||
if self._cleanup_whitelist is not None:
|
||||
return self._cleanup_whitelist
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
self._cleanup_whitelist = set()
|
||||
return self._cleanup_whitelist
|
||||
|
||||
try:
|
||||
whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist()
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch cleanup whitelist from billing service")
|
||||
whitelist_ids = []
|
||||
|
||||
self._cleanup_whitelist = set(whitelist_ids)
|
||||
return self._cleanup_whitelist
|
||||
|
||||
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
return trigger_repo.delete_by_run_ids(run_ids)
|
||||
|
||||
def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
return trigger_repo.count_by_run_ids(run_ids)
|
||||
|
||||
@staticmethod
|
||||
def _build_run_contexts(
|
||||
runs: Sequence[WorkflowRun],
|
||||
) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]:
|
||||
return [
|
||||
{
|
||||
"run_id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"app_id": run.app_id,
|
||||
"workflow_id": run.workflow_id,
|
||||
"triggered_from": run.triggered_from,
|
||||
}
|
||||
for run in runs
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _empty_related_counts() -> dict[str, int]:
|
||||
return {
|
||||
"node_executions": 0,
|
||||
"offloads": 0,
|
||||
"app_logs": 0,
|
||||
"trigger_logs": 0,
|
||||
"pauses": 0,
|
||||
"pause_reasons": 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_related_counts(counts: dict[str, int]) -> str:
|
||||
return (
|
||||
f"node_executions {counts['node_executions']}, "
|
||||
f"offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, "
|
||||
f"trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, "
|
||||
f"pause_reasons {counts['pause_reasons']}"
|
||||
)
|
||||
|
||||
def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
|
||||
run_contexts = self._build_run_contexts(runs)
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts)
|
||||
|
||||
def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
|
||||
run_contexts = self._build_run_contexts(runs)
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts)
|
||||
@@ -12,6 +12,7 @@ from libs.passport import PassportService
|
||||
from libs.password import compare_password
|
||||
from models import Account, AccountStatus
|
||||
from models.model import App, EndUser, Site
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
@@ -32,7 +33,7 @@ class WebAppAuthService:
|
||||
@staticmethod
|
||||
def authenticate(email: str, password: str) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
@@ -52,7 +53,7 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
account = db.session.query(Account).where(Account.email == email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
|
||||
158
api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml
vendored
Normal file
158
api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml
vendored
Normal file
@@ -0,0 +1,158 @@
|
||||
app:
|
||||
description: Validate v1 Variable Assigner blocks streaming until conversation variable is updated.
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: advanced-chat
|
||||
name: test_streaming_conversation_variables_v1_overwrite
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies: []
|
||||
kind: app
|
||||
version: 0.5.0
|
||||
workflow:
|
||||
conversation_variables:
|
||||
- description: ''
|
||||
id: 6ddf2d7f-3d1b-4bb0-9a5e-9b0c87c7b5e6
|
||||
name: conv_var
|
||||
selector:
|
||||
- conversation
|
||||
- conv_var
|
||||
value: default
|
||||
value_type: string
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
fileUploadConfig:
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
image_file_size_limit: 10
|
||||
video_file_size_limit: 100
|
||||
workflow_file_upload_limit: 10
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: start
|
||||
targetType: assigner
|
||||
id: start-source-assigner-target
|
||||
source: start
|
||||
sourceHandle: source
|
||||
target: assigner
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInLoop: false
|
||||
sourceType: assigner
|
||||
targetType: answer
|
||||
id: assigner-source-answer-target
|
||||
source: assigner
|
||||
sourceHandle: source
|
||||
target: answer
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
nodes:
|
||||
- data:
|
||||
desc: ''
|
||||
selected: false
|
||||
title: Start
|
||||
type: start
|
||||
variables: []
|
||||
height: 54
|
||||
id: start
|
||||
position:
|
||||
x: 30
|
||||
y: 253
|
||||
positionAbsolute:
|
||||
x: 30
|
||||
y: 253
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
- data:
|
||||
answer: 'Current Value Of `conv_var` is:{{#conversation.conv_var#}}'
|
||||
desc: ''
|
||||
selected: false
|
||||
title: Answer
|
||||
type: answer
|
||||
variables: []
|
||||
height: 106
|
||||
id: answer
|
||||
position:
|
||||
x: 638
|
||||
y: 253
|
||||
positionAbsolute:
|
||||
x: 638
|
||||
y: 253
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
- data:
|
||||
assigned_variable_selector:
|
||||
- conversation
|
||||
- conv_var
|
||||
desc: ''
|
||||
input_variable_selector:
|
||||
- sys
|
||||
- query
|
||||
selected: false
|
||||
title: Variable Assigner
|
||||
type: assigner
|
||||
write_mode: over-write
|
||||
height: 84
|
||||
id: assigner
|
||||
position:
|
||||
x: 334
|
||||
y: 253
|
||||
positionAbsolute:
|
||||
x: 334
|
||||
y: 253
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
viewport:
|
||||
x: 0
|
||||
y: 0
|
||||
zoom: 0.7
|
||||
@@ -230,7 +230,6 @@ class TestAgentService:
|
||||
|
||||
# Create first agent thought
|
||||
thought1 = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=1,
|
||||
thought="I need to analyze the user's request",
|
||||
@@ -257,7 +256,6 @@ class TestAgentService:
|
||||
|
||||
# Create second agent thought
|
||||
thought2 = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=2,
|
||||
thought="Based on the analysis, I can provide a response",
|
||||
@@ -545,7 +543,6 @@ class TestAgentService:
|
||||
|
||||
# Create agent thought with tool error
|
||||
thought_with_error = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=1,
|
||||
thought="I need to analyze the user's request",
|
||||
@@ -759,7 +756,6 @@ class TestAgentService:
|
||||
|
||||
# Create agent thought with multiple tools
|
||||
complex_thought = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=1,
|
||||
thought="I need to use multiple tools to complete this task",
|
||||
@@ -877,7 +873,6 @@ class TestAgentService:
|
||||
|
||||
# Create agent thought with files
|
||||
thought_with_files = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=1,
|
||||
thought="I need to process some files",
|
||||
@@ -957,7 +952,6 @@ class TestAgentService:
|
||||
|
||||
# Create agent thought with empty tool data
|
||||
empty_thought = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=1,
|
||||
thought="I need to analyze the user's request",
|
||||
@@ -999,7 +993,6 @@ class TestAgentService:
|
||||
|
||||
# Create agent thought with malformed JSON
|
||||
malformed_thought = MessageAgentThought(
|
||||
id=fake.uuid4(),
|
||||
message_id=message.id,
|
||||
position=1,
|
||||
thought="I need to analyze the user's request",
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestActivateCheckApi:
|
||||
"tenant": tenant,
|
||||
}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking valid invitation token.
|
||||
@@ -66,7 +66,7 @@ class TestActivateCheckApi:
|
||||
assert response["data"]["workspace_id"] == "workspace-123"
|
||||
assert response["data"]["email"] == "invitee@example.com"
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test checking invalid invitation token.
|
||||
@@ -88,7 +88,7 @@ class TestActivateCheckApi:
|
||||
# Assert
|
||||
assert response["is_valid"] is False
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without workspace ID.
|
||||
@@ -109,7 +109,7 @@ class TestActivateCheckApi:
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without email parameter.
|
||||
@@ -130,6 +130,20 @@ class TestActivateCheckApi:
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
|
||||
"""Ensure token validation uses lowercase emails."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
with app.test_request_context(
|
||||
"/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token"
|
||||
):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
|
||||
|
||||
class TestActivateApi:
|
||||
"""Test cases for account activation endpoint."""
|
||||
@@ -212,7 +226,7 @@ class TestActivateApi:
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test account activation with invalid token.
|
||||
@@ -241,7 +255,7 @@ class TestActivateApi:
|
||||
with pytest.raises(AlreadyActivateError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_sets_interface_theme(
|
||||
@@ -290,7 +304,7 @@ class TestActivateApi:
|
||||
("es-ES", "Europe/Madrid"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_with_different_locales(
|
||||
@@ -336,7 +350,7 @@ class TestActivateApi:
|
||||
assert mock_account.interface_language == language
|
||||
assert mock_account.timezone == timezone
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_returns_success_response(
|
||||
@@ -376,7 +390,7 @@ class TestActivateApi:
|
||||
# Assert
|
||||
assert response == {"result": "success"}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_without_workspace_id(
|
||||
@@ -415,3 +429,37 @@ class TestActivateApi:
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_normalizes_email_before_lookup(
|
||||
self,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
):
|
||||
"""Ensure uppercase emails are normalized before lookup and revocation."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "Invitee@Example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
assert response["result"] == "success"
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
@@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
@@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
EmailRegisterResetApi,
|
||||
EmailRegisterSendEmailApi,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestEmailRegisterSendEmailApi:
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
|
||||
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_email_normalizes_and_falls_back(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_email_send_ip_limit,
|
||||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
mock_account = MagicMock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/send-email",
|
||||
method="POST",
|
||||
json={"email": "Invitee@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailRegisterSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_is_freeze.assert_called_once_with("invitee@example.com")
|
||||
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestEmailRegisterCheckApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
|
||||
def test_validity_normalizes_email_before_checks(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = EmailRegisterCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("user@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="4321", additional_data={"phase": "register"}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestEmailRegisterResetApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_creates_account_with_normalized_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
@@ -0,0 +1,176 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_normalizes_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
controller_features = SimpleNamespace(is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch(
|
||||
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
|
||||
return_value=controller_features,
|
||||
),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=mock_account,
|
||||
email="user@example.com",
|
||||
language="zh-Hans",
|
||||
is_allow_register=True,
|
||||
)
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_check_normalizes_email(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("admin@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"Admin@Example.com",
|
||||
code="4321",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("admin@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_fetches_account_with_original_email(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("token-123")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
@@ -76,7 +76,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@@ -120,7 +120,7 @@ class TestLoginApi:
|
||||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert response.json["result"] == "success"
|
||||
@@ -128,7 +128,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@@ -182,7 +182,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login rejection when rate limit is exceeded.
|
||||
@@ -230,7 +230,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
def test_login_fails_with_invalid_credentials(
|
||||
@@ -269,7 +269,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
def test_login_fails_for_banned_account(
|
||||
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
|
||||
@@ -298,7 +298,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@@ -343,7 +343,7 @@ class TestLoginApi:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login failure when invitation email doesn't match login email.
|
||||
@@ -371,6 +371,52 @@ class TestLoginApi:
|
||||
with pytest.raises(InvalidEmailError):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_login_retries_with_lowercase_email(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login_service,
|
||||
mock_get_tenants,
|
||||
mock_add_rate_limit,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""Test that login retries with lowercase email when uppercase lookup fails."""
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
mock_login_service.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.json["result"] == "success"
|
||||
assert mock_authenticate.call_args_list == [
|
||||
(("Upper@Example.com", "ValidPass123!", None), {}),
|
||||
(("upper@example.com", "ValidPass123!", None), {}),
|
||||
]
|
||||
mock_add_rate_limit.assert_not_called()
|
||||
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
"""Test cases for the LogoutApi endpoint."""
|
||||
|
||||
@@ -12,6 +12,7 @@ from controllers.console.auth.oauth import (
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
@@ -215,6 +216,34 @@ class TestOAuthCallback:
|
||||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_invitation_comparison_is_case_insensitive(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_register_service,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
|
||||
id="123", name="Test User", email="User@Example.com"
|
||||
)
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_register_service.is_valid_invite_token.return_value = True
|
||||
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||
resource.get("github")
|
||||
|
||||
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
@@ -395,12 +424,12 @@ class TestAccountGeneration:
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.oauth.Session")
|
||||
@patch("controllers.console.auth.oauth.select")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
|
||||
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for Session creation
|
||||
mock_db.engine = MagicMock()
|
||||
@@ -410,15 +439,31 @@ class TestAccountGeneration:
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||
mock_get_account.assert_not_called()
|
||||
|
||||
# Test fallback to email
|
||||
# Test fallback to email lookup
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
|
||||
mock_session = MagicMock()
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert result == expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
@@ -466,6 +511,35 @@ class TestAccountGeneration:
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
):
|
||||
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US"}):
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
|
||||
@@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import (
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi:
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_success(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
@@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
@@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi:
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_language_handling(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
@@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
@@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi:
|
||||
assert response["email"] == "test@example.com"
|
||||
assert response["token"] == "new_token"
|
||||
mock_revoke_token.assert_called_once_with("old_token")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"test@example.com", code="123456", additional_data={"phase": "reset"}
|
||||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
def test_verify_code_preserves_token_email_case(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_generate_token,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "999888", "token": "upper_token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"User@Example.com", code="999888", additional_data={"phase": "reset"}
|
||||
)
|
||||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
@@ -355,20 +395,16 @@ class TestForgotPasswordResetApi:
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||
def test_reset_password_success(
|
||||
self,
|
||||
mock_get_tenants,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
@@ -383,11 +419,8 @@ class TestForgotPasswordResetApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
||||
# Act
|
||||
@@ -475,13 +508,11 @@ class TestForgotPasswordResetApi:
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
@@ -491,11 +522,8 @@ class TestForgotPasswordResetApi:
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
|
||||
39
api/tests/unit_tests/controllers/console/test_setup.py
Normal file
39
api/tests/unit_tests/controllers/console/test_setup.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.setup import SetupApi
|
||||
|
||||
|
||||
class TestSetupApi:
|
||||
def test_post_lowercases_email_before_register(self):
|
||||
"""Ensure setup registration normalizes email casing."""
|
||||
payload = {
|
||||
"email": "Admin@Example.com",
|
||||
"name": "Admin User",
|
||||
"password": "ValidPass123!",
|
||||
"language": "en-US",
|
||||
}
|
||||
setup_api = SetupApi(api=None)
|
||||
|
||||
mock_console_ns = SimpleNamespace(payload=payload)
|
||||
|
||||
with (
|
||||
patch("controllers.console.setup.console_ns", mock_console_ns),
|
||||
patch("controllers.console.setup.get_setup_status", return_value=False),
|
||||
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
|
||||
patch("controllers.console.setup.get_init_validate_status", return_value=True),
|
||||
patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"),
|
||||
patch("controllers.console.setup.request", object()),
|
||||
patch("controllers.console.setup.RegisterService.setup") as mock_register,
|
||||
):
|
||||
response, status = setup_api.post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
assert status == 201
|
||||
mock_register.assert_called_once_with(
|
||||
email="admin@example.com",
|
||||
name=payload["name"],
|
||||
password=payload["password"],
|
||||
ip_address="127.0.0.1",
|
||||
language=payload["language"],
|
||||
)
|
||||
@@ -0,0 +1,247 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
ChangeEmailResetApi,
|
||||
ChangeEmailSendEmailApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
return app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
account = Account(name=account_id, email=email)
|
||||
account.email = email
|
||||
account.id = account_id
|
||||
account.status = "active"
|
||||
account._current_tenant = tenant_obj
|
||||
return account
|
||||
|
||||
|
||||
def _set_logged_in_user(account: Account):
|
||||
g._login_user = account
|
||||
g._current_tenant = account.current_tenant
|
||||
|
||||
|
||||
class TestChangeEmailSend:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-abc"}
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=None,
|
||||
email="new@example.com",
|
||||
old_email="current@example.com",
|
||||
language="en-US",
|
||||
phase="new_email",
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_validate_with_normalized_email(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_before_update(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "New@Example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_is_freeze.assert_called_once_with("new@example.com")
|
||||
mock_check_unique.assert_called_once_with("new@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "feedback": "test"},
|
||||
):
|
||||
response = AccountDeleteUpdateFeedbackApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_update.assert_called_once_with("User@Example.com", "test")
|
||||
|
||||
|
||||
class TestCheckEmailUnique:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/check-email-unique",
|
||||
method="POST",
|
||||
json={"email": "Case@Test.com"},
|
||||
):
|
||||
response = CheckEmailUnique().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_is_freeze.assert_called_once_with("case@test.com")
|
||||
mock_check_unique.assert_called_once_with("case@test.com")
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
session = MagicMock()
|
||||
first = MagicMock()
|
||||
first.scalar_one_or_none.return_value = None
|
||||
second = MagicMock()
|
||||
expected_account = MagicMock()
|
||||
second.scalar_one_or_none.return_value = expected_account
|
||||
session.execute.side_effect = [first, second]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
|
||||
|
||||
assert result is expected_account
|
||||
assert session.execute.call_count == 2
|
||||
@@ -0,0 +1,82 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.workspace.members import MemberInviteEmailApi
|
||||
from models.account import Account, TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
return flask_app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_feature_flags():
|
||||
placeholder_quota = SimpleNamespace(limit=0, size=0)
|
||||
workspace_members = SimpleNamespace(is_available=lambda count: True)
|
||||
return SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
workspace_members=workspace_members,
|
||||
members=placeholder_quota,
|
||||
apps=placeholder_quota,
|
||||
vector_space=placeholder_quota,
|
||||
documents_upload_quota=placeholder_quota,
|
||||
annotation_quota_limit=placeholder_quota,
|
||||
)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
@patch("controllers.console.workspace.members.FeatureService.get_features")
|
||||
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
|
||||
@patch("controllers.console.workspace.members.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
def test_invite_normalizes_emails(
|
||||
self,
|
||||
mock_csrf,
|
||||
mock_db,
|
||||
mock_current_account,
|
||||
mock_invite_member,
|
||||
mock_get_features,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
|
||||
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
|
||||
mock_current_account.return_value = (inviter, tenant.id)
|
||||
|
||||
with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"):
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/members/invite-email",
|
||||
method="POST",
|
||||
json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"},
|
||||
):
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account._current_tenant = tenant
|
||||
g._login_user = account
|
||||
g._current_tenant = tenant
|
||||
response, status_code = MemberInviteEmailApi().post()
|
||||
|
||||
assert status_code == 201
|
||||
assert response["invitation_results"][0]["email"] == "user@example.com"
|
||||
|
||||
assert mock_invite_member.call_count == 1
|
||||
call_args = mock_invite_member.call_args
|
||||
assert call_args.kwargs["tenant"] == tenant
|
||||
assert call_args.kwargs["email"] == "User@Example.com"
|
||||
assert call_args.kwargs["language"] == "en-US"
|
||||
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
|
||||
assert call_args.kwargs["inviter"] == inviter
|
||||
mock_csrf.assert_called_once()
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Unit tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.forgot_password using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.forgot_password"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
forgot_password_module = _load_controller_module()
|
||||
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
|
||||
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
|
||||
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_web_endpoint_guards():
|
||||
"""Stub enterprise and feature toggles used by route decorators."""
|
||||
|
||||
features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_controller_db():
|
||||
"""Replace controller-level db reference with a simple stub."""
|
||||
|
||||
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
|
||||
fake_wraps_db = SimpleNamespace(
|
||||
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
|
||||
)
|
||||
with (
|
||||
patch("controllers.web.forgot_password.db", fake_db),
|
||||
patch("controllers.console.wraps.db", fake_wraps_db),
|
||||
):
|
||||
yield fake_db
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
|
||||
def test_send_reset_email_success(
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password returns token when email exists and limits allow."""
|
||||
|
||||
mock_account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "user@example.com"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "reset-token"}
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
|
||||
def test_check_token_success(
|
||||
mock_is_rate_limited: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke: MagicMock,
|
||||
mock_generate: MagicMock,
|
||||
mock_reset_limit: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/validity validates the code and refreshes token."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limited.assert_called_once_with("user@example.com")
|
||||
mock_get_data.assert_called_once_with("old-token")
|
||||
mock_revoke.assert_called_once_with("old-token")
|
||||
mock_generate.assert_called_once_with(
|
||||
"user@example.com",
|
||||
code="123456",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_success(
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_token_bytes: MagicMock,
|
||||
mock_hash_password: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/resets updates the stored password when token is valid."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
|
||||
account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
session_ctx.commit.assert_called_once()
|
||||
226
api/tests/unit_tests/controllers/web/test_web_forgot_password.py
Normal file
226
api/tests/unit_tests/controllers/web/test_web_forgot_password.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config", dify_settings),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
def test_should_normalize_email_before_sending(
|
||||
self,
|
||||
mock_session_cls,
|
||||
mock_extract_ip,
|
||||
mock_rate_limit,
|
||||
mock_get_account,
|
||||
mock_send_mail,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_should_normalize_email_for_validity_checks(
|
||||
self,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"User@Example.com",
|
||||
code="1234",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_should_preserve_token_email_case(
|
||||
self,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"MixedCase@Example.com",
|
||||
code="5678",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_revoke_token.assert_called_once_with("token-upper")
|
||||
mock_reset_rate.assert_called_once_with("mixedcase@example.com")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_should_fetch_account_with_fallback(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_should_update_password_and_commit(
|
||||
self,
|
||||
mock_get_account,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_token_bytes,
|
||||
mock_hash_password,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||
account = MagicMock()
|
||||
mock_get_account.return_value = account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
mock_session.commit.assert_called_once()
|
||||
91
api/tests/unit_tests/controllers/web/test_web_login.py
Normal file
91
api/tests/unit_tests/controllers/web/test_web_login.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
|
||||
web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config", console_dify),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
patch("controllers.web.login.dify_config", web_dify),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendEmailApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.send_email_code_login_email")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
|
||||
def test_should_fetch_account_with_original_email(
|
||||
self,
|
||||
mock_get_user,
|
||||
mock_send_email,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailCodeLoginSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
|
||||
|
||||
|
||||
class TestEmailCodeLoginApi:
|
||||
@patch("controllers.web.login.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
|
||||
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data")
|
||||
def test_should_normalize_email_before_validating(
|
||||
self,
|
||||
mock_get_token_data,
|
||||
mock_revoke_token,
|
||||
mock_get_user,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_user.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
response = EmailCodeLoginApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
@@ -1,390 +0,0 @@
|
||||
"""
|
||||
Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method,
|
||||
specifically testing the ANSWER node message_replace logic.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
from core.workflow.enums import NodeType
|
||||
from models import EndUser
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestAnswerNodeMessageReplace:
|
||||
"""Test cases for ANSWER node message_replace event logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock(spec=AdvancedChatAppGenerateEntity)
|
||||
entity.task_id = "test-task-id"
|
||||
entity.app_id = "test-app-id"
|
||||
entity.workflow_run_id = "test-workflow-run-id"
|
||||
# minimal app_config used by pipeline internals
|
||||
entity.app_config = SimpleNamespace(
|
||||
tenant_id="test-tenant-id",
|
||||
app_id="test-app-id",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
app_model_config_dict={},
|
||||
additional_features=None,
|
||||
sensitive_word_avoidance=None,
|
||||
)
|
||||
entity.query = "test query"
|
||||
entity.files = []
|
||||
entity.extras = {}
|
||||
entity.trace_manager = None
|
||||
entity.inputs = {}
|
||||
entity.invoke_from = "debugger"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow(self):
|
||||
"""Create a mock workflow."""
|
||||
workflow = Mock()
|
||||
workflow.id = "test-workflow-id"
|
||||
workflow.features_dict = {}
|
||||
return workflow
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = Mock()
|
||||
manager.listen.return_value = []
|
||||
manager.graph_runtime_state = None
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation(self):
|
||||
"""Create a mock conversation."""
|
||||
conversation = Mock()
|
||||
conversation.id = "test-conversation-id"
|
||||
conversation.mode = "advanced_chat"
|
||||
return conversation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message."""
|
||||
message = Mock()
|
||||
message.id = "test-message-id"
|
||||
message.query = "test query"
|
||||
message.created_at = Mock()
|
||||
message.created_at.timestamp.return_value = 1234567890
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock end user."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = "test-user-id"
|
||||
user.session_id = "test-session-id"
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def mock_draft_var_saver_factory(self):
|
||||
"""Create a mock draft variable saver factory."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(
|
||||
self,
|
||||
mock_application_generate_entity,
|
||||
mock_workflow,
|
||||
mock_queue_manager,
|
||||
mock_conversation,
|
||||
mock_message,
|
||||
mock_user,
|
||||
mock_draft_var_saver_factory,
|
||||
):
|
||||
"""Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies."""
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
|
||||
with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"):
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=mock_application_generate_entity,
|
||||
workflow=mock_workflow,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
user=mock_user,
|
||||
stream=True,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=mock_draft_var_saver_factory,
|
||||
)
|
||||
# Initialize workflow run id to avoid validation errors
|
||||
pipeline._workflow_run_id = "test-workflow-run-id"
|
||||
# Mock the message cycle manager methods we need to track
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = Mock()
|
||||
return pipeline
|
||||
|
||||
def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity):
|
||||
"""
|
||||
Test that when an ANSWER node's final output differs from accumulated answer,
|
||||
a message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "initial answer"
|
||||
|
||||
# Create ANSWER node succeeded event with different final output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "updated final answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter to avoid extra processing
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert
|
||||
assert pipeline._task_state.answer == "updated final answer"
|
||||
# Verify message_replace was called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
|
||||
answer="updated final answer", reason="variable_update"
|
||||
)
|
||||
|
||||
def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's final output is the same as accumulated answer,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "same answer"
|
||||
|
||||
# Create ANSWER node succeeded event with same output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "same answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "same answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's output is None or missing 'answer' key,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with None output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": None},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node has empty outputs dict,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with empty outputs
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_no_answer_key_in_outputs(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's outputs don't contain 'answer' key,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event without 'answer' key in outputs
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"other_key": "some value"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_non_answer_node_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Test with LLM node
|
||||
llm_event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-llm-execution-id",
|
||||
node_id="test-llm-node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "different answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(llm_event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_end_node_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that END nodes don't trigger message_replace events even with 'answer' output.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create END node succeeded event with answer output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-end-execution-id",
|
||||
node_id="test-end-node",
|
||||
node_type=NodeType.END,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "different answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_numeric_output_converts_to_string(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's final output is numeric,
|
||||
it gets converted to string properly.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "text answer"
|
||||
|
||||
# Create ANSWER node succeeded event with numeric output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": 12345},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should be converted to string
|
||||
assert pipeline._task_state.answer == "12345"
|
||||
# Verify message_replace was called with string
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
|
||||
answer="12345", reason="variable_update"
|
||||
)
|
||||
|
||||
def test_answer_node_files_are_recorded(self, pipeline):
|
||||
"""
|
||||
Test that ANSWER nodes properly record files from outputs.
|
||||
"""
|
||||
# Arrange
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with files
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={
|
||||
"answer": "same answer",
|
||||
"files": [
|
||||
{"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"])
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: files should be recorded
|
||||
assert len(pipeline._recorded_files) == 1
|
||||
assert pipeline._recorded_files[0] == event.outputs["files"][0]
|
||||
@@ -45,3 +45,33 @@ def test_streaming_conversation_variables():
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
|
||||
|
||||
def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment():
|
||||
fixture_name = "test_streaming_conversation_variables_v1_overwrite"
|
||||
input_query = "overwrite-value"
|
||||
|
||||
case = WorkflowTestCase(
|
||||
fixture_path=fixture_name,
|
||||
use_auto_mock=False,
|
||||
mock_config=MockConfigBuilder().build(),
|
||||
query=input_query,
|
||||
inputs={},
|
||||
expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"},
|
||||
)
|
||||
|
||||
runner = TableTestRunner()
|
||||
result = runner.run_test_case(case)
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
|
||||
events = result.events
|
||||
conv_var_chunk_events = [
|
||||
event
|
||||
for event in events
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var")
|
||||
]
|
||||
|
||||
assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted"
|
||||
assert all(event.chunk == input_query for event in conv_var_chunk_events), (
|
||||
"Expected streamed conversation variable value to match the input query"
|
||||
)
|
||||
|
||||
@@ -58,6 +58,8 @@ def test_json_object_valid_schema():
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -68,7 +70,7 @@ def test_json_object_valid_schema():
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
|
||||
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
result = node._run()
|
||||
@@ -87,6 +89,8 @@ def test_json_object_invalid_json_string():
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -97,12 +101,12 @@ def test_json_object_invalid_json_string():
|
||||
)
|
||||
]
|
||||
|
||||
# Missing closing brace makes this invalid JSON
|
||||
# Providing a string instead of an object should raise a type error
|
||||
user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
|
||||
with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
|
||||
node._run()
|
||||
|
||||
|
||||
@@ -118,6 +122,8 @@ def test_json_object_does_not_match_schema():
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -129,7 +135,7 @@ def test_json_object_does_not_match_schema():
|
||||
]
|
||||
|
||||
# age is a string, which violates the schema (expects number)
|
||||
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
|
||||
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
@@ -149,6 +155,8 @@ def test_json_object_missing_required_schema_field():
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -160,7 +168,7 @@ def test_json_object_missing_required_schema_field():
|
||||
]
|
||||
|
||||
# Missing required field "name"
|
||||
user_inputs = {"profile": json.dumps({"age": 20})}
|
||||
user_inputs = {"profile": {"age": 20}}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
|
||||
@@ -2,13 +2,17 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.enums import FileType
|
||||
from core.file.models import File, FileTransferMethod
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@@ -96,6 +100,58 @@ class TestWorkflowEntry:
|
||||
assert output_var is not None
|
||||
assert output_var.value == "system_user"
|
||||
|
||||
def test_single_step_run_injects_code_limits(self):
|
||||
"""Ensure single-step CodeNode execution configures limits."""
|
||||
# Arrange
|
||||
node_id = "code_node"
|
||||
node_data = {
|
||||
"type": "code",
|
||||
"title": "Code",
|
||||
"desc": None,
|
||||
"variables": [],
|
||||
"code_language": CodeLanguage.PYTHON3,
|
||||
"code": "def main():\n return {}",
|
||||
"outputs": {},
|
||||
}
|
||||
node_config = {"id": node_id, "data": node_data}
|
||||
|
||||
class StubWorkflow:
|
||||
def __init__(self):
|
||||
self.tenant_id = "tenant"
|
||||
self.app_id = "app"
|
||||
self.id = "workflow"
|
||||
self.graph_dict = {"nodes": [node_config], "edges": []}
|
||||
|
||||
def get_node_config_by_id(self, target_id: str):
|
||||
assert target_id == node_id
|
||||
return node_config
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
|
||||
expected_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
|
||||
# Act
|
||||
node, _ = WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_id="user",
|
||||
user_inputs={},
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(node, CodeNode)
|
||||
assert node._limits == expected_limits
|
||||
|
||||
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
|
||||
"""Test mapping environment variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with environment variables
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
@@ -104,6 +105,42 @@ class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||
return pause
|
||||
|
||||
|
||||
class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def test_get_runs_batch_by_time_range_filters_terminal_statuses(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
|
||||
):
|
||||
scalar_result = Mock()
|
||||
scalar_result.all.return_value = []
|
||||
mock_session.scalars.return_value = scalar_result
|
||||
|
||||
repository.get_runs_batch_by_time_range(
|
||||
start_from=None,
|
||||
end_before=datetime(2024, 1, 1),
|
||||
last_seen=None,
|
||||
batch_size=50,
|
||||
)
|
||||
|
||||
stmt = mock_session.scalars.call_args[0][0]
|
||||
compiled_sql = str(
|
||||
stmt.compile(
|
||||
dialect=postgresql.dialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "workflow_runs.status" in compiled_sql
|
||||
for status in (
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
):
|
||||
assert f"'{status.value}'" in compiled_sql
|
||||
|
||||
assert "'running'" not in compiled_sql
|
||||
assert "'paused'" not in compiled_sql
|
||||
|
||||
|
||||
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test create_workflow_pause method."""
|
||||
|
||||
@@ -181,6 +218,61 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
|
||||
node_ids_result = Mock()
|
||||
node_ids_result.all.return_value = []
|
||||
pause_ids_result = Mock()
|
||||
pause_ids_result.all.return_value = []
|
||||
mock_session.scalars.side_effect = [node_ids_result, pause_ids_result]
|
||||
|
||||
# app_logs delete, runs delete
|
||||
mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)]
|
||||
|
||||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.delete_by_run_ids.return_value = 3
|
||||
|
||||
run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
|
||||
counts = repository.delete_runs_with_related(
|
||||
[run],
|
||||
delete_node_executions=lambda session, runs: (2, 1),
|
||||
delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
|
||||
)
|
||||
|
||||
fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
|
||||
assert counts["node_executions"] == 2
|
||||
assert counts["offloads"] == 1
|
||||
assert counts["trigger_logs"] == 3
|
||||
assert counts["runs"] == 1
|
||||
|
||||
|
||||
class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
|
||||
pause_ids_result = Mock()
|
||||
pause_ids_result.all.return_value = ["pause-1", "pause-2"]
|
||||
mock_session.scalars.return_value = pause_ids_result
|
||||
mock_session.scalar.side_effect = [5, 2]
|
||||
|
||||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.count_by_run_ids.return_value = 3
|
||||
|
||||
run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
|
||||
counts = repository.count_runs_with_related(
|
||||
[run],
|
||||
count_node_executions=lambda session, runs: (2, 1),
|
||||
count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
|
||||
)
|
||||
|
||||
fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"])
|
||||
assert counts["node_executions"] == 2
|
||||
assert counts["offloads"] == 1
|
||||
assert counts["trigger_logs"] == 3
|
||||
assert counts["app_logs"] == 5
|
||||
assert counts["pauses"] == 2
|
||||
assert counts["pause_reasons"] == 2
|
||||
assert counts["runs"] == 1
|
||||
|
||||
|
||||
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test resume_workflow_pause method."""
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
|
||||
|
||||
def test_delete_by_run_ids_executes_delete():
|
||||
session = Mock(spec=Session)
|
||||
session.execute.return_value = Mock(rowcount=2)
|
||||
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
|
||||
deleted = repo.delete_by_run_ids(["run-1", "run-2"])
|
||||
|
||||
stmt = session.execute.call_args[0][0]
|
||||
compiled_sql = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
|
||||
assert "workflow_trigger_logs" in compiled_sql
|
||||
assert "'run-1'" in compiled_sql
|
||||
assert "'run-2'" in compiled_sql
|
||||
assert deleted == 2
|
||||
|
||||
|
||||
def test_delete_by_run_ids_empty_short_circuits():
|
||||
session = Mock(spec=Session)
|
||||
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
|
||||
deleted = repo.delete_by_run_ids([])
|
||||
|
||||
session.execute.assert_not_called()
|
||||
assert deleted == 0
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from models.account import Account
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
@@ -1147,9 +1147,13 @@ class TestRegisterService:
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||
|
||||
with patch("services.account_service.Session") as mock_session_class:
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
# Mock RegisterService.register
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
@@ -1182,9 +1186,59 @@ class TestRegisterService:
|
||||
email="newuser@example.com",
|
||||
name="newuser",
|
||||
language="en-US",
|
||||
status="pending",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
|
||||
|
||||
def test_invite_new_member_normalizes_new_account_email(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
):
|
||||
"""Ensure inviting with mixed-case email normalizes before registering."""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-456"
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
mixed_email = "Invitee@Example.com"
|
||||
|
||||
mock_session = MagicMock()
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending"
|
||||
)
|
||||
with patch("services.account_service.RegisterService.register") as mock_register:
|
||||
mock_register.return_value = mock_new_account
|
||||
with (
|
||||
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
|
||||
patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
|
||||
patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
|
||||
patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
|
||||
):
|
||||
mock_generate_token.return_value = "invite-token-abc"
|
||||
|
||||
RegisterService.invite_new_member(
|
||||
tenant=mock_tenant,
|
||||
email=mixed_email,
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
)
|
||||
|
||||
mock_register.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
name="invitee",
|
||||
language="en-US",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
|
||||
@@ -1207,9 +1261,13 @@ class TestRegisterService:
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||
|
||||
with patch("services.account_service.Session") as mock_session_class:
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock the db.session.query for TenantAccountJoin
|
||||
mock_db_query = MagicMock()
|
||||
@@ -1238,6 +1296,7 @@ class TestRegisterService:
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
||||
mock_task_dependencies.delay.assert_called_once()
|
||||
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
|
||||
|
||||
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
||||
"""Test inviting a member who is already in the tenant."""
|
||||
@@ -1251,7 +1310,6 @@ class TestRegisterService:
|
||||
|
||||
# Mock database queries
|
||||
query_results = {
|
||||
("Account", "email", "existing@example.com"): mock_existing_account,
|
||||
(
|
||||
"TenantAccountJoin",
|
||||
"tenant_id",
|
||||
@@ -1261,7 +1319,11 @@ class TestRegisterService:
|
||||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
|
||||
# Mock TenantService methods
|
||||
with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission:
|
||||
with (
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
|
||||
):
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountAlreadyInTenantError,
|
||||
@@ -1272,6 +1334,7 @@ class TestRegisterService:
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
)
|
||||
mock_lookup.assert_called_once()
|
||||
|
||||
def test_invite_new_member_no_inviter(self):
|
||||
"""Test inviting a member without providing an inviter."""
|
||||
@@ -1497,6 +1560,30 @@ class TestRegisterService:
|
||||
# Verify results
|
||||
assert result is None
|
||||
|
||||
def test_get_invitation_with_case_fallback_returns_initial_match(self):
|
||||
"""Fallback helper should return the initial invitation when present."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch(
|
||||
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
|
||||
) as mock_get:
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
assert result == invitation
|
||||
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
|
||||
"""Fallback helper should retry with lowercase email when needed."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
|
||||
mock_get.side_effect = [None, invitation]
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
assert result == invitation
|
||||
assert mock_get.call_args_list == [
|
||||
(("tenant-456", "User@Test.com", "token-123"),),
|
||||
(("tenant-456", "user@test.com", "token-123"),),
|
||||
]
|
||||
|
||||
# ==================== Helper Method Tests ====================
|
||||
|
||||
def test_get_invitation_token_key(self):
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from services.billing_service import SubscriptionPlan
|
||||
from services.retention.workflow_run import clear_free_plan_expired_workflow_run_logs as cleanup_module
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
|
||||
|
||||
class FakeRun:
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str,
|
||||
tenant_id: str,
|
||||
created_at: datetime.datetime,
|
||||
app_id: str = "app-1",
|
||||
workflow_id: str = "wf-1",
|
||||
triggered_from: str = "workflow-run",
|
||||
) -> None:
|
||||
self.id = run_id
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
self.workflow_id = workflow_id
|
||||
self.triggered_from = triggered_from
|
||||
self.created_at = created_at
|
||||
|
||||
|
||||
class FakeRepo:
|
||||
def __init__(
|
||||
self,
|
||||
batches: list[list[FakeRun]],
|
||||
delete_result: dict[str, int] | None = None,
|
||||
count_result: dict[str, int] | None = None,
|
||||
) -> None:
|
||||
self.batches = batches
|
||||
self.call_idx = 0
|
||||
self.deleted: list[list[str]] = []
|
||||
self.counted: list[list[str]] = []
|
||||
self.delete_result = delete_result or {
|
||||
"runs": 0,
|
||||
"node_executions": 0,
|
||||
"offloads": 0,
|
||||
"app_logs": 0,
|
||||
"trigger_logs": 0,
|
||||
"pauses": 0,
|
||||
"pause_reasons": 0,
|
||||
}
|
||||
self.count_result = count_result or {
|
||||
"runs": 0,
|
||||
"node_executions": 0,
|
||||
"offloads": 0,
|
||||
"app_logs": 0,
|
||||
"trigger_logs": 0,
|
||||
"pauses": 0,
|
||||
"pause_reasons": 0,
|
||||
}
|
||||
|
||||
def get_runs_batch_by_time_range(
|
||||
self,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
last_seen: tuple[datetime.datetime, str] | None,
|
||||
batch_size: int,
|
||||
) -> list[FakeRun]:
|
||||
if self.call_idx >= len(self.batches):
|
||||
return []
|
||||
batch = self.batches[self.call_idx]
|
||||
self.call_idx += 1
|
||||
return batch
|
||||
|
||||
def delete_runs_with_related(
|
||||
self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None
|
||||
) -> dict[str, int]:
|
||||
self.deleted.append([run.id for run in runs])
|
||||
result = self.delete_result.copy()
|
||||
result["runs"] = len(runs)
|
||||
return result
|
||||
|
||||
def count_runs_with_related(
|
||||
self, runs: list[FakeRun], count_node_executions=None, count_trigger_logs=None
|
||||
) -> dict[str, int]:
|
||||
self.counted.append([run.id for run in runs])
|
||||
result = self.count_result.copy()
|
||||
result["runs"] = len(runs)
|
||||
return result
|
||||
|
||||
|
||||
def plan_info(plan: str, expiration: int) -> SubscriptionPlan:
|
||||
return SubscriptionPlan(plan=plan, expiration_date=expiration)
|
||||
|
||||
|
||||
def create_cleanup(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
repo: FakeRepo,
|
||||
*,
|
||||
grace_period_days: int = 0,
|
||||
whitelist: set[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> WorkflowRunCleanup:
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.dify_config,
|
||||
"SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD",
|
||||
grace_period_days,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.WorkflowRunCleanup,
|
||||
"_get_cleanup_whitelist",
|
||||
lambda self: whitelist or set(),
|
||||
)
|
||||
return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs)
|
||||
|
||||
|
||||
def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
def fail_bulk(_: list[str]) -> dict[str, SubscriptionPlan]:
|
||||
raise RuntimeError("should not call")
|
||||
|
||||
monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fail_bulk))
|
||||
|
||||
tenants = {"t1", "t2"}
|
||||
free = cleanup._filter_free_tenants(tenants)
|
||||
|
||||
assert free == tenants
|
||||
|
||||
|
||||
def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.BillingService,
|
||||
"get_plan_bulk_with_cache",
|
||||
staticmethod(
|
||||
lambda tenant_ids: {
|
||||
tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1))
|
||||
for tenant_id in tenant_ids
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"})
|
||||
|
||||
assert free == {"t_free", "t_missing"}
|
||||
|
||||
|
||||
def test_filter_free_tenants_respects_grace_period(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, grace_period_days=45)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
within_grace_ts = int((now - datetime.timedelta(days=10)).timestamp())
|
||||
outside_grace_ts = int((now - datetime.timedelta(days=90)).timestamp())
|
||||
|
||||
def fake_bulk(_: list[str]) -> dict[str, SubscriptionPlan]:
|
||||
return {
|
||||
"recently_downgraded": plan_info("sandbox", within_grace_ts),
|
||||
"long_sandbox": plan_info("sandbox", outside_grace_ts),
|
||||
}
|
||||
|
||||
monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fake_bulk))
|
||||
|
||||
free = cleanup._filter_free_tenants({"recently_downgraded", "long_sandbox"})
|
||||
|
||||
assert free == {"long_sandbox"}
|
||||
|
||||
|
||||
def test_filter_free_tenants_skips_cleanup_whitelist(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(
|
||||
monkeypatch,
|
||||
repo=FakeRepo([]),
|
||||
days=30,
|
||||
batch_size=10,
|
||||
whitelist={"tenant_whitelist"},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.BillingService,
|
||||
"get_plan_bulk_with_cache",
|
||||
staticmethod(
|
||||
lambda tenant_ids: {
|
||||
tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1))
|
||||
for tenant_id in tenant_ids
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tenants = {"tenant_whitelist", "tenant_regular"}
|
||||
free = cleanup._filter_free_tenants(tenants)
|
||||
|
||||
assert free == {"tenant_regular"}
|
||||
|
||||
|
||||
def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.BillingService,
|
||||
"get_plan_bulk_with_cache",
|
||||
staticmethod(lambda tenant_ids: (_ for _ in ()).throw(RuntimeError("boom"))),
|
||||
)
|
||||
|
||||
free = cleanup._filter_free_tenants({"t1", "t2"})
|
||||
|
||||
assert free == set()
|
||||
|
||||
|
||||
def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
batches=[
|
||||
[
|
||||
FakeRun("run-free", "t_free", cutoff),
|
||||
FakeRun("run-paid", "t_paid", cutoff),
|
||||
]
|
||||
]
|
||||
)
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.BillingService,
|
||||
"get_plan_bulk_with_cache",
|
||||
staticmethod(
|
||||
lambda tenant_ids: {
|
||||
tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1))
|
||||
for tenant_id in tenant_ids
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
cleanup.run()
|
||||
|
||||
assert repo.deleted == [["run-free"]]
|
||||
|
||||
|
||||
def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]])
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
cleanup_module.BillingService,
|
||||
"get_plan_bulk_with_cache",
|
||||
staticmethod(lambda tenant_ids: {tenant_id: plan_info("team", 1893456000) for tenant_id in tenant_ids}),
|
||||
)
|
||||
|
||||
cleanup.run()
|
||||
|
||||
assert repo.deleted == []
|
||||
|
||||
|
||||
def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
cleanup.run()
|
||||
|
||||
|
||||
def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
batches=[[FakeRun("run-free", "t_free", cutoff)]],
|
||||
count_result={
|
||||
"runs": 0,
|
||||
"node_executions": 2,
|
||||
"offloads": 1,
|
||||
"app_logs": 3,
|
||||
"trigger_logs": 4,
|
||||
"pauses": 5,
|
||||
"pause_reasons": 6,
|
||||
},
|
||||
)
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10, dry_run=True)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
cleanup.run()
|
||||
|
||||
assert repo.deleted == []
|
||||
assert repo.counted == [["run-free"]]
|
||||
captured = capsys.readouterr().out
|
||||
assert "Dry run mode enabled" in captured
|
||||
assert "would delete 1 runs" in captured
|
||||
assert "related records" in captured
|
||||
assert "node_executions 2" in captured
|
||||
assert "offloads 1" in captured
|
||||
assert "app_logs 3" in captured
|
||||
assert "trigger_logs 4" in captured
|
||||
assert "pauses 5" in captured
|
||||
assert "pause_reasons 6" in captured
|
||||
|
||||
|
||||
def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
start_from = datetime.datetime(2024, 5, 1, 0, 0, 0)
|
||||
end_before = datetime.datetime(2024, 6, 1, 0, 0, 0)
|
||||
cleanup = create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before
|
||||
)
|
||||
|
||||
assert cleanup.window_start == start_from
|
||||
assert cleanup.window_end == end_before
|
||||
|
||||
|
||||
def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=datetime.datetime.now(), end_before=None
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=None, end_before=datetime.datetime.now()
|
||||
)
|
||||
|
||||
|
||||
def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
start_from = datetime.datetime(2024, 6, 1, 0, 0, 0)
|
||||
end_before = datetime.datetime(2024, 5, 1, 0, 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before
|
||||
)
|
||||
55
api/uv.lock
generated
55
api/uv.lock
generated
@@ -1368,7 +1368,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.11.2"
|
||||
version = "1.11.4"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
@@ -1731,7 +1731,7 @@ storage = [
|
||||
{ name = "opendal", specifier = "~=0.46.0" },
|
||||
{ name = "oss2", specifier = "==2.18.5" },
|
||||
{ name = "supabase", specifier = "~=2.18.1" },
|
||||
{ name = "tos", specifier = "~=2.7.1" },
|
||||
{ name = "tos", specifier = "~=2.9.0" },
|
||||
]
|
||||
tools = [
|
||||
{ name = "cloudscraper", specifier = "~=1.2.71" },
|
||||
@@ -6148,7 +6148,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "tos"
|
||||
version = "2.7.2"
|
||||
version = "2.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "crcmod" },
|
||||
@@ -6156,8 +6156,9 @@ dependencies = [
|
||||
{ name = "pytz" },
|
||||
{ name = "requests" },
|
||||
{ name = "six" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/01/f811af86f1f80d5f289be075c3b281e74bf3fe081cfbe5cfce44954d2c3a/tos-2.7.2.tar.gz", hash = "sha256:3c31257716785bca7b2cac51474ff32543cda94075a7b7aff70d769c15c7b7ed", size = 123407, upload-time = "2024-10-16T15:59:08.634Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9a/b3/13451226f564f88d9db2323e9b7eabcced792a0ad5ee1e333751a7634257/tos-2.9.0.tar.gz", hash = "sha256:861cfc348e770f099f911cb96b2c41774ada6c9c51b7a89d97e0c426074dd99e", size = 157071, upload-time = "2026-01-06T04:13:08.921Z" }
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
@@ -7146,31 +7147,31 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.17.3"
|
||||
version = "1.16.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0", size = 55547, upload-time = "2025-08-12T05:53:21.714Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/95/4c/063a912e20bcef7124e0df97282a8af3ff3e4b603ce84c481d6d7346be0a/wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d", size = 53972, upload-time = "2023-11-09T06:33:30.191Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/52/db/00e2a219213856074a213503fdac0511203dceefff26e1daa15250cc01a0/wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7", size = 53482, upload-time = "2025-08-12T05:51:45.79Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/30/ca3c4a5eba478408572096fe9ce36e6e915994dd26a4e9e98b4f729c06d9/wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85", size = 38674, upload-time = "2025-08-12T05:51:34.629Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/25/3e8cc2c46b5329c5957cec959cb76a10718e1a513309c31399a4dad07eb3/wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f", size = 38959, upload-time = "2025-08-12T05:51:56.074Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/8f/a32a99fc03e4b37e31b57cb9cefc65050ea08147a8ce12f288616b05ef54/wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311", size = 82376, upload-time = "2025-08-12T05:52:32.134Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/57/4930cb8d9d70d59c27ee1332a318c20291749b4fba31f113c2f8ac49a72e/wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1", size = 83604, upload-time = "2025-08-12T05:52:11.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/f3/1afd48de81d63dd66e01b263a6fbb86e1b5053b419b9b33d13e1f6d0f7d0/wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5", size = 82782, upload-time = "2025-08-12T05:52:12.626Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/d7/4ad5327612173b144998232f98a85bb24b60c352afb73bc48e3e0d2bdc4e/wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2", size = 82076, upload-time = "2025-08-12T05:52:33.168Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/59/e0adfc831674a65694f18ea6dc821f9fcb9ec82c2ce7e3d73a88ba2e8718/wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89", size = 36457, upload-time = "2025-08-12T05:53:03.936Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/83/88/16b7231ba49861b6f75fc309b11012ede4d6b0a9c90969d9e0db8d991aeb/wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77", size = 38745, upload-time = "2025-08-12T05:53:02.885Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/1e/c4d4f3398ec073012c51d1c8d87f715f56765444e1a4b11e5180577b7e6e/wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a", size = 36806, upload-time = "2025-08-12T05:52:53.368Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0", size = 53998, upload-time = "2025-08-12T05:51:47.138Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba", size = 39020, upload-time = "2025-08-12T05:51:35.906Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd", size = 39098, upload-time = "2025-08-12T05:51:57.474Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828", size = 88036, upload-time = "2025-08-12T05:52:34.784Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9", size = 88156, upload-time = "2025-08-12T05:52:13.599Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396", size = 87102, upload-time = "2025-08-12T05:52:14.56Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc", size = 87732, upload-time = "2025-08-12T05:52:36.165Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe", size = 36705, upload-time = "2025-08-12T05:53:07.123Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c", size = 38877, upload-time = "2025-08-12T05:53:05.436Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6", size = 36885, upload-time = "2025-08-12T05:52:54.367Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/03/c188ac517f402775b90d6f312955a5e53b866c964b32119f2ed76315697e/wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09", size = 37313, upload-time = "2023-11-09T06:31:52.168Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/16/ea627d7817394db04518f62934a5de59874b587b792300991b3c347ff5e0/wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d", size = 38164, upload-time = "2023-11-09T06:31:53.522Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/a7/f1212ba098f3de0fd244e2de0f8791ad2539c03bef6c05a9fcb03e45b089/wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389", size = 80890, upload-time = "2023-11-09T06:31:55.247Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/96/bb5e08b3d6db003c9ab219c487714c13a237ee7dcc572a555eaf1ce7dc82/wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060", size = 73118, upload-time = "2023-11-09T06:31:57.023Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/52/2da48b35193e39ac53cfb141467d9f259851522d0e8c87153f0ba4205fb1/wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1", size = 80746, upload-time = "2023-11-09T06:31:58.686Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/11/fb/18ec40265ab81c0e82a934de04596b6ce972c27ba2592c8b53d5585e6bcd/wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3", size = 85668, upload-time = "2023-11-09T06:31:59.992Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/ef/0ecb1fa23145560431b970418dce575cfaec555ab08617d82eb92afc7ccf/wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956", size = 78556, upload-time = "2023-11-09T06:32:01.942Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/62/cd284b2b747f175b5a96cbd8092b32e7369edab0644c45784871528eb852/wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d", size = 85712, upload-time = "2023-11-09T06:32:03.686Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/a7/47b7ff74fbadf81b696872d5ba504966591a3468f1bc86bca2f407baef68/wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362", size = 35327, upload-time = "2023-11-09T06:32:05.284Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/c3/0084351951d9579ae83a3d9e38c140371e4c6b038136909235079f2e6e78/wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89", size = 37523, upload-time = "2023-11-09T06:32:07.17Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/17/224132494c1e23521868cdd57cd1e903f3b6a7ba6996b7b8f077ff8ac7fe/wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b", size = 37614, upload-time = "2023-11-09T06:32:08.859Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/d7/cfcd73e8f4858079ac59d9db1ec5a1349bc486ae8e9ba55698cc1f4a1dff/wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36", size = 38316, upload-time = "2023-11-09T06:32:10.719Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/79/5ff0a5c54bda5aec75b36453d06be4f83d5cd4932cc84b7cb2b52cee23e2/wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73", size = 86322, upload-time = "2023-11-09T06:32:12.592Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/81/e799bf5d419f422d8712108837c1d9bf6ebe3cb2a81ad94413449543a923/wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809", size = 79055, upload-time = "2023-11-09T06:32:14.394Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/62/30ca2405de6a20448ee557ab2cd61ab9c5900be7cbd18a2639db595f0b98/wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b", size = 87291, upload-time = "2023-11-09T06:32:16.201Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/49/4e/5d2f6d7b57fc9956bf06e944eb00463551f7d52fc73ca35cfc4c2cdb7aed/wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81", size = 90374, upload-time = "2023-11-09T06:32:18.052Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/9b/c2c21b44ff5b9bf14a83252a8b973fb84923764ff63db3e6dfc3895cf2e0/wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9", size = 83896, upload-time = "2023-11-09T06:32:19.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/26/93a9fa02c6f257df54d7570dfe8011995138118d11939a4ecd82cb849613/wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c", size = 91738, upload-time = "2023-11-09T06:32:20.989Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/5b/4660897233eb2c8c4de3dc7cefed114c61bacb3c28327e64150dc44ee2f6/wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc", size = 35568, upload-time = "2023-11-09T06:32:22.715Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/cc/8297f9658506b224aa4bd71906447dea6bb0ba629861a758c28f67428b91/wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8", size = 37653, upload-time = "2023-11-09T06:32:24.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/21/abdedb4cdf6ff41ebf01a74087740a709e2edb146490e4d9beea054b0b7a/wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1", size = 23362, upload-time = "2023-11-09T06:33:28.271Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1478,6 +1478,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
|
||||
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
|
||||
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
|
||||
ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.11.2
|
||||
image: langgenius/dify-web:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -662,6 +662,7 @@ x-shared-env: &shared-api-worker-env
|
||||
ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false}
|
||||
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false}
|
||||
ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false}
|
||||
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: ${ENABLE_WORKFLOW_RUN_CLEANUP_TASK:-false}
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false}
|
||||
ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false}
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}
|
||||
@@ -703,7 +704,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -745,7 +746,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -784,7 +785,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -814,7 +815,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.11.2
|
||||
image: langgenius/dify-web:1.11.4
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -31,6 +31,8 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
|
||||
|
||||
# The timeout for the text generation in millisecond
|
||||
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
|
||||
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only)
|
||||
TEXT_GENERATION_TIMEOUT_MS=60000
|
||||
|
||||
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
|
||||
NEXT_PUBLIC_CSP_WHITELIST=
|
||||
|
||||
1
web/.nvmrc
Normal file
1
web/.nvmrc
Normal file
@@ -0,0 +1 @@
|
||||
24
|
||||
@@ -1,8 +1,10 @@
|
||||
import type { Preview } from '@storybook/react'
|
||||
import type { Resource } from 'i18next'
|
||||
import { withThemeByDataAttribute } from '@storybook/addon-themes'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { ToastProvider } from '../app/components/base/toast'
|
||||
import I18N from '../app/components/i18n'
|
||||
import { I18nClientProvider as I18N } from '../app/components/provider/i18n'
|
||||
import commonEnUS from '../i18n/en-US/common.json'
|
||||
|
||||
import '../app/styles/globals.css'
|
||||
import '../app/styles/markdown.scss'
|
||||
@@ -16,6 +18,14 @@ const queryClient = new QueryClient({
|
||||
},
|
||||
})
|
||||
|
||||
const storyResources: Resource = {
|
||||
'en-US': {
|
||||
// Preload the most common namespace to avoid missing keys during initial render;
|
||||
// other namespaces will be loaded on demand via resourcesToBackend.
|
||||
common: commonEnUS as unknown as Record<string, unknown>,
|
||||
},
|
||||
}
|
||||
|
||||
export const decorators = [
|
||||
withThemeByDataAttribute({
|
||||
themes: {
|
||||
@@ -28,7 +38,7 @@ export const decorators = [
|
||||
(Story) => {
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<I18N locale="en-US">
|
||||
<I18N locale="en-US" resource={storyResources}>
|
||||
<ToastProvider>
|
||||
<Story />
|
||||
</ToastProvider>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# base image
|
||||
FROM node:22-alpine3.21 AS base
|
||||
FROM node:24-alpine AS base
|
||||
LABEL maintainer="takatost@gmail.com"
|
||||
|
||||
# if you located in China, you can use aliyun mirror to speed up
|
||||
|
||||
@@ -8,8 +8,18 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
|
||||
|
||||
Before starting the web frontend service, please make sure the following environment is ready.
|
||||
|
||||
- [Node.js](https://nodejs.org) >= v22.11.x
|
||||
- [pnpm](https://pnpm.io) v10.x
|
||||
- [Node.js](https://nodejs.org)
|
||||
- [pnpm](https://pnpm.io)
|
||||
|
||||
> [!TIP]
|
||||
> It is recommended to install and enable Corepack to manage package manager versions automatically:
|
||||
>
|
||||
> ```bash
|
||||
> npm install -g corepack
|
||||
> corepack enable
|
||||
> ```
|
||||
>
|
||||
> Learn more: [Corepack](https://github.com/nodejs/corepack#readme)
|
||||
|
||||
First, install the dependencies:
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => {
|
||||
)
|
||||
return {
|
||||
useGlobalPublicStore,
|
||||
useIsSystemFeaturesPending: () => false,
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ export default function CheckCode() {
|
||||
setIsLoading(true)
|
||||
const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token })
|
||||
if (ret.result === 'success') {
|
||||
setWebAppAccessToken(ret.data.access_token)
|
||||
if (ret?.data?.access_token) {
|
||||
setWebAppAccessToken(ret.data.access_token)
|
||||
}
|
||||
const { access_token } = await fetchAccessToken({
|
||||
appCode: appCode!,
|
||||
userId: embeddedUserId || undefined,
|
||||
|
||||
@@ -82,7 +82,9 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
||||
body: loginData,
|
||||
})
|
||||
if (res.result === 'success') {
|
||||
setWebAppAccessToken(res.data.access_token)
|
||||
if (res?.data?.access_token) {
|
||||
setWebAppAccessToken(res.data.access_token)
|
||||
}
|
||||
|
||||
const { access_token } = await fetchAccessToken({
|
||||
appCode: appCode!,
|
||||
|
||||
@@ -9,8 +9,8 @@ import {
|
||||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
|
||||
} from '@/app/education-apply/constants'
|
||||
import { fetchSetupStatus } from '@/service/common'
|
||||
import { sendGAEvent } from '@/utils/gtag'
|
||||
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
|
||||
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
|
||||
import { trackEvent } from './base/amplitude'
|
||||
|
||||
@@ -33,15 +33,8 @@ export const AppInitializer = ({
|
||||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
try {
|
||||
if (localStorage.getItem('setup_status') === 'finished')
|
||||
return true
|
||||
const setUpStatus = await fetchSetupStatus()
|
||||
if (setUpStatus.step !== 'finished') {
|
||||
localStorage.removeItem('setup_status')
|
||||
return false
|
||||
}
|
||||
localStorage.setItem('setup_status', 'finished')
|
||||
return true
|
||||
const setUpStatus = await fetchSetupStatusWithCache()
|
||||
return setUpStatus.step === 'finished'
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
|
||||
@@ -125,7 +125,6 @@ const resetAccessControlStore = () => {
|
||||
const resetGlobalStore = () => {
|
||||
useGlobalPublicStore.setState({
|
||||
systemFeatures: defaultSystemFeatures,
|
||||
isGlobalPending: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
|
||||
if (!isJsonObject || !tempPayload.json_schema)
|
||||
return ''
|
||||
try {
|
||||
return JSON.stringify(JSON.parse(tempPayload.json_schema), null, 2)
|
||||
return tempPayload.json_schema
|
||||
}
|
||||
catch {
|
||||
return ''
|
||||
|
||||
@@ -29,6 +29,7 @@ import { CheckModal } from '@/hooks/use-pay'
|
||||
import { useInfiniteAppList } from '@/service/use-apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { isServer } from '@/utils/client'
|
||||
import AppCard from './app-card'
|
||||
import { AppCardSkeleton } from './app-card-skeleton'
|
||||
import Empty from './empty'
|
||||
@@ -71,7 +72,7 @@ const List = () => {
|
||||
// 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all
|
||||
useEffect(() => {
|
||||
// avoid running on server
|
||||
if (typeof window === 'undefined')
|
||||
if (isServer)
|
||||
return
|
||||
const mode = searchParams.get('mode')
|
||||
if (!mode)
|
||||
|
||||
@@ -54,7 +54,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => {
|
||||
}
|
||||
|
||||
const AmplitudeProvider: FC<IAmplitudeProps> = ({
|
||||
sessionReplaySampleRate = 1,
|
||||
sessionReplaySampleRate = 0.5,
|
||||
}) => {
|
||||
useEffect(() => {
|
||||
// Only enable in Saas edition with valid API key
|
||||
|
||||
@@ -37,7 +37,7 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
|
||||
return
|
||||
}
|
||||
|
||||
if (!inputValue)
|
||||
if (inputValue == null)
|
||||
return
|
||||
|
||||
if (item.type === InputVarType.singleFile) {
|
||||
@@ -52,6 +52,20 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
|
||||
else
|
||||
processedInputs[item.variable] = getProcessedFiles(inputValue)
|
||||
}
|
||||
else if (item.type === InputVarType.jsonObject) {
|
||||
// Prefer sending an object if the user entered valid JSON; otherwise keep the raw string.
|
||||
try {
|
||||
const v = typeof inputValue === 'string' ? JSON.parse(inputValue) : inputValue
|
||||
if (v && typeof v === 'object' && !Array.isArray(v))
|
||||
processedInputs[item.variable] = v
|
||||
else
|
||||
processedInputs[item.variable] = inputValue
|
||||
}
|
||||
catch {
|
||||
// keep original string; backend will parse/validate
|
||||
processedInputs[item.variable] = inputValue
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return processedInputs
|
||||
|
||||
@@ -11,6 +11,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { isClient } from '@/utils/client'
|
||||
import {
|
||||
useEmbeddedChatbotContext,
|
||||
} from '../context'
|
||||
@@ -40,7 +41,6 @@ const Header: FC<IHeaderProps> = ({
|
||||
allInputsHidden,
|
||||
} = useEmbeddedChatbotContext()
|
||||
|
||||
const isClient = typeof window !== 'undefined'
|
||||
const isIframe = isClient ? window.self !== window.top : false
|
||||
const [parentOrigin, setParentOrigin] = useState('')
|
||||
const [showToggleExpandButton, setShowToggleExpandButton] = useState(false)
|
||||
|
||||
@@ -362,6 +362,18 @@ describe('PreviewDocumentPicker', () => {
|
||||
expect(screen.getByText('--')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render when value prop is omitted (optional)', () => {
|
||||
const files = createMockDocumentList(2)
|
||||
const onChange = vi.fn()
|
||||
// Do not pass `value` at all to verify optional behavior
|
||||
render(<PreviewDocumentPicker files={files} onChange={onChange} />)
|
||||
|
||||
// Renders placeholder for missing name
|
||||
expect(screen.getByText('--')).toBeInTheDocument()
|
||||
// Portal wrapper renders
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty files array', () => {
|
||||
renderComponent({ files: [] })
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import DocumentList from './document-list'
|
||||
|
||||
type Props = {
|
||||
className?: string
|
||||
value: DocumentItem
|
||||
value?: DocumentItem
|
||||
files: DocumentItem[]
|
||||
onChange: (value: DocumentItem) => void
|
||||
}
|
||||
@@ -30,7 +30,8 @@ const PreviewDocumentPicker: FC<Props> = ({
|
||||
onChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { name, extension } = value
|
||||
const name = value?.name || ''
|
||||
const extension = value?.extension
|
||||
|
||||
const [open, {
|
||||
set: setOpen,
|
||||
|
||||
3
web/app/components/provider/serwist.tsx
Normal file
3
web/app/components/provider/serwist.tsx
Normal file
@@ -0,0 +1,3 @@
|
||||
'use client'
|
||||
|
||||
export { SerwistProvider } from '@serwist/turbopack/react'
|
||||
@@ -195,7 +195,7 @@ const RunOnce: FC<IRunOnceProps> = ({
|
||||
noWrapper
|
||||
className="bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1"
|
||||
placeholder={
|
||||
<div className="whitespace-pre">{item.json_schema}</div>
|
||||
<div className="whitespace-pre">{typeof item.json_schema === 'string' ? item.json_schema : JSON.stringify(item.json_schema || '', null, 2)}</div>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -13,6 +13,7 @@ import Tooltip from '@/app/components/base/tooltip'
|
||||
import InstallFromMarketplace from '@/app/components/plugins/install-plugin/install-from-marketplace'
|
||||
import Action from '@/app/components/workflow/block-selector/market-place-plugin/action'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { isServer } from '@/utils/client'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import BlockIcon from '../block-icon'
|
||||
@@ -49,14 +50,14 @@ const FeaturedTools = ({
|
||||
const language = useGetLanguage()
|
||||
const [visibleCount, setVisibleCount] = useState(INITIAL_VISIBLE_COUNT)
|
||||
const [isCollapsed, setIsCollapsed] = useState<boolean>(() => {
|
||||
if (typeof window === 'undefined')
|
||||
if (isServer)
|
||||
return false
|
||||
const stored = window.localStorage.getItem(STORAGE_KEY)
|
||||
return stored === 'true'
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window === 'undefined')
|
||||
if (isServer)
|
||||
return
|
||||
const stored = window.localStorage.getItem(STORAGE_KEY)
|
||||
if (stored !== null)
|
||||
@@ -64,7 +65,7 @@ const FeaturedTools = ({
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window === 'undefined')
|
||||
if (isServer)
|
||||
return
|
||||
window.localStorage.setItem(STORAGE_KEY, String(isCollapsed))
|
||||
}, [isCollapsed])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user