Compare commits

..

121 Commits

Author SHA1 Message Date
JzoNg
81baeae5c4 fix(web): evaluation workflow switch 2026-04-03 18:22:44 +08:00
JzoNg
a3010bdc0b Merge branch 'main' into jzh 2026-04-03 18:05:54 +08:00
Coding On Star
83d4176785 test: add unit tests for app store and annotation components, enhancing coverage for state management and UI interactions (#34510)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 09:09:59 +00:00
JzoNg
8133e550ed chore: fix pre-hook of web 2026-04-03 16:21:32 +08:00
JzoNg
2bb0eab636 chore(web): mapping row refactor 2026-04-03 16:10:41 +08:00
JzoNg
5311b5d00d feat(web): available evaluation workflow selector 2026-04-03 16:06:33 +08:00
JzoNg
9b02ccdd12 Merge branch 'main' into jzh 2026-04-03 15:15:11 +08:00
JzoNg
231783eebe chore(web): fix lint 2026-04-03 15:13:52 +08:00
yyh
c94951b2f8 refactor(web): migrate notion page selectors to tanstack virtual (#34508)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 07:03:12 +00:00
Matt Van Horn
a9cf8f6c5d refactor(web): replace react-syntax-highlighter with shiki (#33473)
Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 06:40:26 +00:00
JzoNg
756606f478 feat(web): hide card view in evaluation 2026-04-03 14:39:41 +08:00
YBoy
64ddec0d67 refactor(api): type annotation service dicts with TypedDict (#34482)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-03 06:25:52 +00:00
JzoNg
6651c1c5da feat(web): workflow switch 2026-04-03 14:22:50 +08:00
Renzo
da3b0caf5e refactor: select in account_service (RegisterService class) (#34500)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 06:21:26 +00:00
JzoNg
61e257b2a8 feat(web): app switch api 2026-04-03 13:56:00 +08:00
Stephen Zhou
4fedd43af5 chore: update code-inspector-plugin to 1.5.1 (#34506) 2026-04-03 05:34:03 +00:00
yyh
a263f28e19 fix(web): restore ui select public exports (#34501) 2026-04-03 04:42:02 +00:00
Stephen Zhou
d53862f135 chore: override lodash (#34502) 2026-04-03 04:40:46 +00:00
Renzo
608958de1c refactor: select in external_knowledge_service (#34493)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 03:42:16 +00:00
Renzo
7eb632eb34 refactor: select in rag_pipeline (#34495)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 03:42:01 +00:00
Renzo
33d4fd357c refactor: select in account_service (AccountService class) (#34496) 2026-04-03 03:41:46 +00:00
agenthaulk
e55bd61c17 refactor: replace useContext with use in selected batch (#34450) 2026-04-03 03:37:35 +00:00
JzoNg
3ac4caf735 Merge branch 'main' into jzh 2026-04-03 11:28:22 +08:00
Stephen Zhou
f2fc213d52 chore: update deps (#34487) 2026-04-03 03:26:49 +00:00
YBoy
f814579ed2 test: migrate service_api dataset controller tests to testcontainers (#34423)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 02:28:47 +00:00
YBoy
71d299d0d3 refactor(api): type hit testing retrieve responses with TypedDict (#34484) 2026-04-03 02:25:30 +00:00
YBoy
e178451d04 refactor(api): type log identity dict with IdentityDict TypedDict (#34485) 2026-04-03 02:25:02 +00:00
YBoy
9a6222f245 refactor(api): type webhook data extraction with RawWebhookDataDict TypedDict (#34486) 2026-04-03 02:24:17 +00:00
YBoy
affe5ed30b refactor(api): type get_knowledge_rate_limit with KnowledgeRateLimitD… (#34483) 2026-04-03 02:23:32 +00:00
wangxiaolei
4cc5401d7e fix: fix import dsl failed (#34492)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 02:08:21 +00:00
Stephen Zhou
36e840cd87 chore: knip fix (#34481)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 15:03:42 +00:00
Tim Ren
985b41c40b fix(security): add tenant_id validation to prevent IDOR in data source binding (#34456)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 13:17:02 +00:00
lif
2e29ac2829 fix: remove redundant cast in MCP base session (#34461)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-04-02 12:36:21 +00:00
Renzo
dbfb474eab refactor: select in workflow_tools_manage_service (#34477) 2026-04-02 12:35:04 +00:00
Renzo
d243de26ec refactor: select in metadata_service (#34479) 2026-04-02 12:34:38 +00:00
Stephen Zhou
894826771a chore: clean up useless tailwind reference (#34478) 2026-04-02 11:45:19 +00:00
Asuka Minato
a3386da5d6 ci: Update pyrefly version to 0.59.1 (#34452)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 09:48:46 +00:00
99
318a3d0308 refactor(api): tighten login and wrapper typing (#34447) 2026-04-02 09:36:58 +00:00
JzoNg
268ae1751d Merge branch 'main' into jzh 2026-04-01 09:26:13 +08:00
JzoNg
015cbf850b Merge branch 'main' into jzh 2026-03-31 18:08:24 +08:00
JzoNg
873e13c2fb feat(web): support select node in metric card 2026-03-31 18:07:52 +08:00
JzoNg
688bf7e7a1 feat(web): metric card style 2026-03-31 17:43:56 +08:00
JzoNg
a6ffff3b39 fix(web): fix style of metric selector 2026-03-31 17:22:07 +08:00
JzoNg
023fc55bd5 fix(web): empty state of metric 2026-03-31 17:11:44 +08:00
JzoNg
351b909a53 feat(web): metric card 2026-03-31 17:00:37 +08:00
JzoNg
6bec4f65c9 refactor(web): metric section refactor 2026-03-31 16:28:48 +08:00
JzoNg
74f87ce152 Merge branch 'main' into jzh 2026-03-31 16:13:04 +08:00
JzoNg
92c472ccc7 Merge branch 'main' into jzh 2026-03-30 15:40:23 +08:00
JzoNg
b92b8becd1 feat(web): metric selector 2026-03-30 15:39:52 +08:00
JzoNg
23d0d6a65d chore(web): i18n of metrics 2026-03-30 14:20:43 +08:00
JzoNg
1660067d6e feat(web): judgement model selector 2026-03-30 14:03:37 +08:00
JzoNg
0642475b85 Merge branch 'main' into jzh 2026-03-30 13:30:10 +08:00
JzoNg
8cb634c9bc feat(web): evaluation layout 2026-03-30 11:27:06 +08:00
JzoNg
768b41c3cf Merge branch 'main' into jzh 2026-03-30 11:07:42 +08:00
JzoNg
ca88516d54 refactor(web): refactor evaluation page 2026-03-30 11:06:41 +08:00
JzoNg
871a2a149f refactor(web): split snippet index 2026-03-30 10:32:59 +08:00
JzoNg
60e381eff0 Merge branch 'main' into jzh 2026-03-30 09:48:58 +08:00
JzoNg
768b3eb6f9 feat(web): test run of snippet 2026-03-29 20:55:11 +08:00
JzoNg
2f88da4a6d feat(web): add variable inspect for snippet 2026-03-29 20:23:24 +08:00
JzoNg
a8cdf6964c feat(web): test run button 2026-03-29 20:02:59 +08:00
JzoNg
985c3db4fd feat(web): snippet input field panel layout 2026-03-29 18:02:27 +08:00
JzoNg
9636472db7 refactor(web): snippet main 2026-03-29 17:50:30 +08:00
JzoNg
0ad268aa7d feat(web): snippet publish 2026-03-29 17:29:37 +08:00
JzoNg
a4ea33167d feat(web): block selector in snippet 2026-03-29 17:01:32 +08:00
JzoNg
0f13aabea8 feat(web): input fields in snippet 2026-03-29 16:31:38 +08:00
JzoNg
1e76ef5ccb chore(web): ignore system vars & conversation vars in rag-pipeline and snippet 2026-03-29 15:56:24 +08:00
JzoNg
e6e3229d17 feat(web): input field button style 2026-03-29 15:45:05 +08:00
JzoNg
dccf8e723a feat(web): snippet version panel 2026-03-29 15:26:59 +08:00
JzoNg
c41ba7d627 feat(web): snippet header in graph 2026-03-29 15:02:34 +08:00
JzoNg
a6e9316de3 Merge branch 'main' into jzh 2026-03-29 14:07:49 +08:00
JzoNg
559d326cbd chore(web): mock data of snippet 2026-03-27 17:24:01 +08:00
JzoNg
abedf2506f Merge branch 'main' into jzh 2026-03-27 17:01:27 +08:00
JzoNg
d01428b5bc feat(web): snippet graph draft sync 2026-03-27 16:02:47 +08:00
JzoNg
0de1f17e5c Merge branch 'main' into jzh 2026-03-27 15:23:49 +08:00
JzoNg
17d07a5a43 feat(web): init snippet graph 2026-03-27 15:23:03 +08:00
JzoNg
3bdbea99a3 Merge branch 'main' into jzh 2026-03-27 14:04:10 +08:00
JzoNg
b7683aedb1 Merge branch 'main' into jzh 2026-03-26 21:38:48 +08:00
JzoNg
515036e758 test(web): add tests for snippets 2026-03-26 21:38:22 +08:00
JzoNg
22b382527f feat(web): add snippet to workflow 2026-03-26 21:26:29 +08:00
JzoNg
2cfe4b5b86 feat(web): snippet graph data fetching 2026-03-26 21:11:09 +08:00
JzoNg
6876c8041c feat(web): snippet list data fetching in block selector 2026-03-26 20:58:42 +08:00
JzoNg
7de45584ce refactor: snippets list 2026-03-26 20:41:51 +08:00
JzoNg
5572d7c7e8 Merge branch 'main' into jzh 2026-03-26 20:10:47 +08:00
JzoNg
db0a2fe52e Merge branch 'main' into jzh 2026-03-26 16:29:44 +08:00
JzoNg
f0ae8d6167 fix(web): unused imports caused by merge 2026-03-26 16:28:56 +08:00
JzoNg
2514e181ba Merge branch 'main' into jzh 2026-03-26 16:16:10 +08:00
JzoNg
be2e6e9a14 Merge branch 'main' into jzh 2026-03-26 14:23:29 +08:00
JzoNg
875e2eac1b Merge branch 'main' into jzh 2026-03-26 08:38:57 +08:00
JzoNg
c3c73ceb1f Merge branch 'main' into jzh 2026-03-25 23:02:18 +08:00
JzoNg
6318bf0a2a feat(web): create snippet from workflow 2026-03-25 22:57:48 +08:00
JzoNg
5e1f252046 feat(web): selection context menu style update 2026-03-25 22:36:27 +08:00
JzoNg
df3b960505 fix(web): position of selection context menu in workflow graph 2026-03-25 22:02:50 +08:00
JzoNg
26bc108bf1 chore(web): tests for snippet info 2026-03-25 21:35:36 +08:00
JzoNg
a5cff32743 feat(web): snippet info operations 2026-03-25 21:29:06 +08:00
JzoNg
d418dd8eec Merge branch 'main' into jzh 2026-03-25 20:17:32 +08:00
JzoNg
61702fe346 Merge branch 'main' into jzh 2026-03-25 18:17:03 +08:00
JzoNg
43f0c780c3 Merge branch 'main' into jzh 2026-03-25 15:30:21 +08:00
JzoNg
30ebf2bfa9 Merge branch 'main' into jzh 2026-03-24 07:25:22 +08:00
JzoNg
7e3027b5f7 feat(web): snippet card usage info 2026-03-23 17:02:00 +08:00
JzoNg
b3acf83090 Merge branch 'main' into jzh 2026-03-23 16:46:26 +08:00
JzoNg
36c3d6e48a feat(web): snippet list fetching & display 2026-03-23 16:37:05 +08:00
JzoNg
f782ac6b3c feat(web): create snippets by DSL import 2026-03-23 14:55:36 +08:00
JzoNg
feef2dd1fa feat(web): add snippet creation dialog flow 2026-03-23 11:29:41 +08:00
JzoNg
a716d8789d refactor: extract snippet list components 2026-03-23 10:48:15 +08:00
JzoNg
6816f89189 Merge branch 'main' into jzh 2026-03-23 10:13:45 +08:00
JzoNg
bfcac64a9d Merge branch 'main' into jzh 2026-03-20 15:33:49 +08:00
JzoNg
664eb601a2 feat(web): add api of snippet workflows 2026-03-20 15:29:53 +08:00
JzoNg
8e5cc4e0aa feat(web): add evaluation api 2026-03-20 15:23:03 +08:00
JzoNg
9f28575903 feat(web): add snippets api 2026-03-20 15:11:33 +08:00
JzoNg
4b9a26a5e6 Merge branch 'main' into jzh 2026-03-20 14:01:34 +08:00
JzoNg
7b85adf1cc Merge branch 'main' into jzh 2026-03-20 10:46:45 +08:00
JzoNg
c964708ebe Merge branch 'main' into jzh 2026-03-18 18:07:20 +08:00
JzoNg
883eb498c0 Merge branch 'main' into jzh 2026-03-18 17:40:51 +08:00
JzoNg
4d3738d225 Merge branch 'main' into feat/evaluation-fe 2026-03-17 10:42:44 +08:00
JzoNg
dd0dee739d Merge branch 'main' into jzh 2026-03-16 15:43:20 +08:00
zxhlyh
4d19914fcb Merge branch 'main' into feat/evaluation-fe 2026-03-16 10:47:37 +08:00
zxhlyh
887c7710e9 feat: evaluation 2026-03-16 10:46:33 +08:00
zxhlyh
7a722773c7 feat: snippet canvas 2026-03-13 17:45:04 +08:00
zxhlyh
a763aff58b feat: snippets list 2026-03-13 16:12:42 +08:00
zxhlyh
c1011f4e5c feat: add to snippet 2026-03-13 14:29:59 +08:00
zxhlyh
f7afa103a5 feat: select snippets 2026-03-13 13:43:29 +08:00
837 changed files with 34657 additions and 14094 deletions

View File

@@ -77,7 +77,7 @@ if $web_modified; then
fi
cd ./web || exit 1
vp staged
pnpm exec vp staged
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
@@ -89,6 +89,12 @@ if $web_modified; then
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! pnpm run knip; then
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
exit 1
fi
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)

View File

@@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = current_user

View File

@@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: Any, **kwargs: Any):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
return wrapper

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import overload
from sqlalchemy import select
@@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
def get_app_model(
view: Callable[..., Any] | None = None,
@overload
def get_app_model[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@@ -68,14 +84,30 @@ def get_app_model(
return decorator(view)
def get_app_model_with_trial(
view: Callable[..., Any] | None = None,
@overload
def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@@ -1,6 +1,4 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@@ -11,7 +9,6 @@ from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -87,39 +84,3 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@@ -158,10 +158,11 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")

View File

@@ -1,4 +1,5 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from flask import Response, request
@@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite(f):
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@@ -70,7 +71,7 @@ def _api_prerequisite(f):
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)

View File

@@ -1,9 +1,10 @@
import inspect
import logging
import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Any, cast, overload
from typing import cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
return interceptor
def validate_dataset_token(
view: Callable[..., Any] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset")
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
positional_parameters = [
parameter
for parameter in inspect.signature(view).parameters.values()
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
@wraps(view)
def decorated(*args: object, **kwargs: object) -> R:
api_token = validate_and_get_api_token("dataset")
# First try to get from kwargs (explicit parameter)
dataset_id = kwargs.get("dataset_id")
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
dataset_id = kwargs.get("dataset_id")
# If not in kwargs, try to extract from positional args
if not dataset_id and args:
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
# Check if first arg is likely a class instance (has __dict__ or __class__)
if len(args) > 1 and hasattr(args[0], "__dict__"):
# This is a class method, dataset_id should be in args[1]
potential_id = args[1]
# Validate it's a string-like UUID, not another object
try:
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
if not dataset_id and args:
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.limit(1)
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
.limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant does not exist.")
if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs)
if expects_bound_instance:
if not args:
raise TypeError("validate_dataset_token expected a bound resource instance.")
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
return decorated
return view(api_token.tenant_id, *args, **kwargs)
if view:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
return decorated
def validate_and_get_api_token(scope: str | None = None):

View File

@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import WebhookService
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
logger = logging.getLogger(__name__)
@@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_id, is_debug=is_debug
)
webhook_data: RawWebhookDataDict
try:
# Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

View File

@@ -3,13 +3,19 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
from typing import Any, TypedDict
import orjson
from configs import dify_config
class IdentityDict(TypedDict, total=False):
tenant_id: str
user_id: str
user_type: str
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
@@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
identity: IdentityDict = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Self, cast
from typing import Any, Self
from httpx import HTTPStatusError
from pydantic import BaseModel
@@ -338,12 +338,11 @@ class BaseSession[
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
@@ -359,15 +358,14 @@ class BaseSession[
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification)
self._handle_incoming(notification)
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)

View File

@@ -1,5 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from flask import Flask
if TYPE_CHECKING:
from extensions.ext_login import DifyLoginManager
class DifyApp(Flask):
pass
"""Flask application type with Dify-specific extension attributes."""
login_manager: DifyLoginManager

View File

@@ -1,17 +1,56 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@@ -19,3 +58,152 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
    """
    Refund quota previously consumed, identified by the charge_id from consume().

    Best-effort by design: this method never raises. Every failure path is
    logged and swallowed so refunds can be attempted from cleanup code safely.

    Args:
        charge_id: The UUID returned from consume()
    """
    try:
        from configs import dify_config
        from services.billing_service import BillingService

        # Nothing to undo when billing is turned off.
        if not dify_config.BILLING_ENABLED:
            return
        if not charge_id:
            logger.warning("Cannot refund: charge_id is empty")
            return

        logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
        result = BillingService.refund_tenant_feature_plan_usage(charge_id)
        succeeded = result.get("result") == "success"
        if succeeded:
            logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
        else:
            logger.warning("Refund failed for charge_id: %s", charge_id)
    except Exception:
        # Catch ALL exceptions - refund must never fail
        logger.exception("Failed to refund quota for charge_id: %s", charge_id)
        # Intentionally no re-raise: refund is best-effort and must be silent.
def get_remaining(self, tenant_id: str) -> int:
    """
    Fetch the tenant's remaining quota for this feature from the billing service.

    Args:
        tenant_id: The tenant identifier

    Returns:
        Remaining quota amount, or -1 when the billing lookup fails.
        Callers (e.g. check()) treat -1 as "unknown" and fail open.
    """
    from services.billing_service import BillingService

    try:
        payload = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
        # The billing API may answer with either a dict carrying a
        # 'remaining' field or a bare numeric value.
        if isinstance(payload, dict):
            return payload.get("remaining", 0)
        return int(payload) if payload else 0
    except Exception:
        logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
        return -1
def unlimited() -> QuotaCharge:
    """
    Build a successful, no-op QuotaCharge carrying no charge_id.

    Used for features exempt from quota limits (the UNLIMITED quota type),
    and as the fail-open result when billing errors occur.
    """
    return QuotaCharge(
        success=True,
        charge_id=None,
        _quota_type=QuotaType.UNLIMITED,
    )

View File

@@ -1,7 +1,8 @@
import json
from typing import cast
import flask_login
from flask import Response, request
from flask import Request, Response, request
from flask_login import user_loaded_from_request, user_logged_in
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
@@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
type LoginUser = Account | EndUser
class DifyLoginManager(flask_login.LoginManager):
"""Project-specific Flask-Login manager with a stable unauthorized contract.
Dify registers `unauthorized_handler` below to always return a JSON `Response`.
Overriding this method lets callers rely on that narrower return type instead of
Flask-Login's broader callback contract.
"""
def unauthorized(self) -> Response:
"""Return the registered unauthorized handler result as a Flask `Response`."""
return cast(Response, super().unauthorized())
def load_user_from_request_context(self) -> None:
"""Populate Flask-Login's request-local user cache for the current request."""
self._load_user()
login_manager = DifyLoginManager()
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None:
"""Load user based on the request."""
del request_from_flask_login
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
@@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login):
raise NotFound("End user not found.")
return end_user
return None
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
def on_user_logged_in(_sender: object, user: LoginUser) -> None:
"""Called when a user logged in.
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
@@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user):
@login_manager.unauthorized_handler
def unauthorized_handler():
def unauthorized_handler() -> Response:
"""Handle unauthorized requests."""
# Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows
# Flask-Login's callback contract based on this override.
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
@@ -123,5 +150,5 @@ def unauthorized_handler():
)
def init_app(app: DifyApp):
def init_app(app: DifyApp) -> None:
login_manager.init_app(app)

View File

@@ -2,19 +2,19 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from flask import current_app, g, has_request_context, request
from flask import Response, current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_login import DifyLoginManager
from libs.token import check_csrf_token
from models import Account
if TYPE_CHECKING:
from flask.typing import ResponseReturnValue
from models.model import EndUser
@@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None:
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
def current_account_with_tenant():
def _get_login_manager() -> DifyLoginManager:
"""Return the project login manager with Dify's narrowed unauthorized contract."""
app = cast(DifyApp, current_app)
return app.login_manager
def current_account_with_tenant() -> tuple[Account, str]:
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
@@ -42,7 +48,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
"""
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# `DifyLoginManager` guarantees that the registered unauthorized handler
# is surfaced here as a concrete Flask `Response`.
unauthorized_response: Response = _get_login_manager().unauthorized()
return unauthorized_response
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
@@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
def _get_user() -> EndUser | Account | None:
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
_get_login_manager().load_user_from_request_context()
return g._login_user

View File

@@ -171,7 +171,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.57.1",
"pyrefly>=0.59.1",
]
############################################################

View File

@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
@@ -144,22 +144,26 @@ class AccountService:
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.query(Account).filter_by(id=user_id).first()
account = db.session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
)
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if not available_ta:
return None
@@ -195,7 +199,7 @@ class AccountService:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@@ -371,8 +375,10 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
account_integrate: AccountIntegrate | None = db.session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
)
if account_integrate:
@@ -416,7 +422,9 @@ class AccountService:
def update_account_email(account: Account, email: str) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
account_integrate = db.session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
@@ -818,7 +826,7 @@ class AccountService:
)
)
account = db.session.query(Account).where(Account.email == email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@@ -1018,7 +1026,7 @@ class AccountService:
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.query(Account).filter_by(email=email).first() is None
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:
@@ -1384,10 +1392,10 @@ class RegisterService:
db.session.add(dify_setup)
db.session.commit()
except Exception as e:
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()
logger.exception("Setup account failed, email: %s, name: %s", email, name)
@@ -1488,7 +1496,11 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta:
TenantService.create_tenant_member(tenant, account, role)
@@ -1545,21 +1557,18 @@ class RegisterService:
if not invitation_data:
return None
tenant = (
db.session.query(Tenant)
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
tenant = db.session.scalar(
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
)
if not tenant:
return None
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
tenant_account = db.session.execute(
select(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
).first()
if not tenant_account:
return None

View File

@@ -4,6 +4,8 @@ import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -23,6 +25,27 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
class AnnotationJobStatusDict(TypedDict):
job_id: str
job_status: str
class EmbeddingModelDict(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationSettingDict(TypedDict):
id: str
enabled: bool
score_threshold: float
embedding_model: EmbeddingModelDict | dict
class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
@@ -85,7 +108,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str):
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@@ -109,7 +132,7 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str):
def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
_, current_tenant_id = current_account_with_tenant()
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
@@ -567,7 +590,7 @@ class AppAnnotationService:
db.session.commit()
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
@@ -602,7 +625,9 @@ class AppAnnotationService:
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (

View File

@@ -18,13 +18,12 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@@ -107,7 +106,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@@ -117,7 +116,6 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(

View File

@@ -22,7 +22,6 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@@ -132,10 +131,9 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
# 7. Check and consume quota
try:
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@@ -155,18 +153,13 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED

View File

@@ -2,7 +2,7 @@ import json
import logging
import os
from collections.abc import Sequence
from typing import Literal, NotRequired, TypedDict
from typing import Literal, TypedDict
import httpx
from pydantic import TypeAdapter
@@ -32,79 +32,9 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
reservation_id: str
available: int
reserved: int
class QuotaCommitResult(TypedDict):
available: int
reserved: int
refunded: int
class QuotaReleaseResult(TypedDict):
available: int
reserved: int
released: int
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _BillingQuota(TypedDict):
size: int
class KnowledgeRateLimitDict(TypedDict):
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Return for sandbox users but is null for other plans, it's defined but never used.
# 2. Keep it for compatibility for now, can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter:
1. validate_python in non-strict mode will coerce it to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
subscription_plan: str
class BillingService:
@@ -119,71 +49,21 @@ class BillingService:
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str) -> BillingInfo:
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return _billing_info_adapter.validate_python(billing_info)
return billing_info
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
return cls._send_request("GET", "/quota/info", params=params)
@classmethod
def quota_reserve(
cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
"""Reserve quota before task execution."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"request_id": request_id,
"amount": amount,
}
if meta:
payload["meta"] = meta
return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
@classmethod
def quota_commit(
cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
"""Commit a reservation with actual consumption."""
payload: dict = {
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
"actual_amount": actual_amount,
}
if meta:
payload["meta"] = meta
return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
"""Release a reservation (cancel, return frozen quota)."""
return _quota_release_adapter.validate_python(
cls._send_request(
"POST",
"/quota/release",
json={
"tenant_id": tenant_id,
"feature_key": feature_key,
"reservation_id": reservation_id,
},
)
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
import httpx
from graphon.nodes.http_request.exc import InvalidHttpMethodError
from sqlalchemy import select
from sqlalchemy import func, select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@@ -103,8 +103,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -112,8 +114,10 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -132,8 +136,10 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -144,9 +150,12 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
)
)
or 0
)
if count > 0:
return True, count
@@ -154,8 +163,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
@@ -163,8 +174,10 @@ class ExternalDatasetService:
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
@@ -238,12 +251,17 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
):
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None:
@@ -286,16 +304,18 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")

View File

@@ -281,7 +281,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]
@@ -312,10 +312,7 @@ class FeatureService:
features.apps.limit = billing_info["apps"]["limit"]
if "vector_space" in billing_info:
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
# but LimitationModel.size is int; truncate here for compatibility
features.vector_space.size = int(billing_info["vector_space"]["size"])
# NOTE END
features.vector_space.size = billing_info["vector_space"]["size"]
features.vector_space.limit = billing_info["vector_space"]["limit"]
if "documents_upload_quota" in billing_info:
@@ -336,11 +333,7 @@ class FeatureService:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
# NOTE (hj24):
# 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used.
# 2. So be careful if later we decide to use [size], we cannot assume it is always present.
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
# NOTE END
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]

View File

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities import LLMMode
@@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
class QueryDict(TypedDict):
content: str
class RetrieveResponseDict(TypedDict):
query: QueryDict
records: list[dict[str, Any]]
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
@@ -150,7 +160,7 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
records = RetrievalService.format_retrieval_documents(documents)
return {
@@ -161,7 +171,7 @@ class HitTestingService:
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict:
records = []
if dataset.provider == "external":
for document in documents:

View File

@@ -1,6 +1,8 @@
import copy
import logging
from sqlalchemy import delete, func, select
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -25,10 +27,14 @@ class MetadataService:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == metadata_args.name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@@ -54,10 +60,14 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@@ -65,7 +75,11 @@ class MetadataService:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@@ -74,9 +88,9 @@ class MetadataService:
metadata.updated_at = naive_utc_now()
# update related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -101,15 +115,19 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# deal related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -224,16 +242,23 @@ class MetadataService:
# deal metadata binding (in the same transaction as the doc_metadata update)
if not operation.partial_update:
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
db.session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.document_id == operation.document_id
)
)
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:
existing_binding = (
db.session.query(DatasetMetadataBinding)
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
.first()
existing_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.document_id == operation.document_id,
DatasetMetadataBinding.metadata_id == metadata_value.id,
)
.limit(1)
)
if existing_binding:
continue
@@ -275,9 +300,13 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
"count": db.session.scalar(
select(func.count(DatasetMetadataBinding.id)).where(
DatasetMetadataBinding.metadata_id == item.get("id"),
DatasetMetadataBinding.dataset_id == dataset.id,
)
)
or 0,
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"

View File

@@ -1,233 +0,0 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
    """
    Result of a quota reservation (Reserve phase).

    Lifecycle:
        charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
        try:
            do_work()
            charge.commit()  # Confirm consumption
        except:
            charge.refund()  # Release frozen quota

    If neither commit() nor refund() is called, the billing system's
    cleanup CronJob will auto-release the reservation within ~75 seconds.
    """

    success: bool
    charge_id: str | None  # reservation_id; None when billing is disabled or reserve failed open
    _quota_type: QuotaType
    _tenant_id: str | None = None
    _feature_key: str | None = None
    _amount: int = 0
    # Guards against double-commit; excluded from repr to keep logs terse.
    _committed: bool = field(default=False, repr=False)

    def commit(self, actual_amount: int | None = None) -> None:
        """
        Confirm the consumption with actual amount.

        Args:
            actual_amount: Actual amount consumed. Defaults to the reserved amount.
                If less than reserved, the difference is refunded automatically.
        """
        # No-op when already committed, or when this charge never produced a
        # real reservation (billing disabled / fail-open reserve).
        if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
            return
        try:
            # Imported lazily to avoid a circular import with services.billing_service.
            from services.billing_service import BillingService

            amount = actual_amount if actual_amount is not None else self._amount
            BillingService.quota_commit(
                tenant_id=self._tenant_id,
                feature_key=self._feature_key,
                reservation_id=self.charge_id,
                actual_amount=amount,
            )
            self._committed = True
            logger.debug(
                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
                self._quota_type,
                self._tenant_id,
                self.charge_id,
                amount,
            )
        except Exception:
            # Best-effort: an uncommitted reservation is auto-released by the
            # billing cleanup CronJob, so we only log here instead of raising.
            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)

    def refund(self) -> None:
        """
        Release the reserved quota (cancel the charge).

        Safe to call even if:
            - charge failed or was disabled (charge_id is None)
            - already committed (Release after Commit is a no-op)
            - already refunded (idempotent)

        This method guarantees no exceptions will be raised.
        """
        if not self.charge_id or not self._tenant_id or not self._feature_key:
            return
        # QuotaService.release itself swallows all exceptions.
        QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
    """Return a charge that always succeeds and has nothing to commit or refund."""
    # Imported lazily; enums.quota_type references this module at import time.
    from enums.quota_type import QuotaType

    no_op_charge = QuotaCharge(
        success=True,
        charge_id=None,
        _quota_type=QuotaType.UNLIMITED,
    )
    return no_op_charge
class QuotaService:
    """Orchestrates quota reserve / commit / release lifecycle via BillingService."""

    @staticmethod
    def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve + immediate Commit (one-shot mode).

        The returned QuotaCharge supports .refund() which calls Release.
        For two-phase usage (e.g. streaming), use reserve() directly.
        """
        charge = QuotaService.reserve(quota_type, tenant_id, amount)
        # Only commit when a real reservation exists; disabled billing or a
        # fail-open reserve yields charge_id=None, which leaves nothing to commit.
        if charge.success and charge.charge_id:
            charge.commit()
        return charge

    @staticmethod
    def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve quota before task execution (Reserve phase only).

        The caller MUST call charge.commit() after the task succeeds,
        or charge.refund() if the task fails.

        Raises:
            QuotaExceededError: When quota is insufficient
            ValueError: When amount is not positive
        """
        # Lazy imports to avoid circular dependencies at module import time.
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
        logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
        if amount <= 0:
            raise ValueError("Amount to reserve must be greater than 0")
        request_id = str(uuid.uuid4())  # idempotency key for the reserve call
        feature_key = quota_type.billing_key
        try:
            reserve_resp = BillingService.quota_reserve(
                tenant_id=tenant_id,
                feature_key=feature_key,
                request_id=request_id,
                amount=amount,
            )
            reservation_id = reserve_resp.get("reservation_id")
            if not reservation_id:
                # Billing answered but granted no reservation: treat as exhausted quota.
                logger.warning(
                    "Reserve returned no reservation_id for %s, feature %s, response: %s",
                    tenant_id,
                    quota_type.value,
                    reserve_resp,
                )
                raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
            logger.debug(
                "Reserved %d %s quota for tenant %s, reservation_id: %s",
                amount,
                quota_type.value,
                tenant_id,
                reservation_id,
            )
            return QuotaCharge(
                success=True,
                charge_id=reservation_id,
                _quota_type=quota_type,
                _tenant_id=tenant_id,
                _feature_key=feature_key,
                _amount=amount,
            )
        except QuotaExceededError:
            raise
        except ValueError:
            raise
        except Exception:
            # Billing outage: fail open with an unlimited charge so product
            # features keep working rather than hard-failing the request.
            logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
            return unlimited()

    @staticmethod
    def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
        """
        Return True when at least `amount` quota remains.

        Fails open (returns True) when billing is disabled or the lookup errors.

        Raises:
            ValueError: When amount is not positive
        """
        if not dify_config.BILLING_ENABLED:
            return True
        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")
        try:
            remaining = QuotaService.get_remaining(quota_type, tenant_id)
            # -1 is the sentinel for "unlimited / unknown" — allow in that case.
            return remaining >= amount if remaining != -1 else True
        except Exception:
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
            return True

    @staticmethod
    def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
        """Release a reservation. Guarantees no exceptions."""
        try:
            from services.billing_service import BillingService

            if not dify_config.BILLING_ENABLED:
                return
            if not reservation_id:
                return
            logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
            BillingService.quota_release(
                tenant_id=tenant_id,
                feature_key=feature_key,
                reservation_id=reservation_id,
            )
        except Exception:
            # Swallow everything: callers (e.g. QuotaCharge.refund) rely on this
            # never raising; the billing cleanup job is the safety net.
            logger.exception("Failed to release quota, reservation_id: %s", reservation_id)

    @staticmethod
    def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
        """
        Return the remaining quota for a tenant/feature.

        Returns -1 to mean "unlimited" (limit == -1) or "unknown" (lookup failed);
        returns 0 when the response shape is unexpected.
        """
        from services.billing_service import BillingService

        try:
            usage_info = BillingService.get_quota_info(tenant_id)
            if isinstance(usage_info, dict):
                feature_info = usage_info.get(quota_type.billing_key, {})
                if isinstance(feature_info, dict):
                    limit = feature_info.get("limit", 0)
                    usage = feature_info.get("usage", 0)
                    if limit == -1:
                        return -1
                    return max(0, limit - usage)
            # NOTE(review): reconstructed from a flattened diff — this fallback
            # appears to cover any non-dict / malformed response; confirm nesting
            # against the original file.
            return 0
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
            return -1

View File

@@ -156,27 +156,27 @@ class RagPipelineService:
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
# check template name is exist
template_name = template_info.name
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
@@ -192,13 +192,13 @@ class RagPipelineService:
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
@@ -210,14 +210,14 @@ class RagPipelineService:
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)
# return draft workflow
@@ -232,28 +232,28 @@ class RagPipelineService:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
.limit(1)
)
return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
@@ -974,7 +974,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
document = db.session.get(Document, document_id.value)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@@ -1178,12 +1178,12 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
@@ -1194,21 +1194,21 @@ class RagPipelineService:
# check template name is exist
template_name = args.get("name")
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
max_position = db.session.scalar(
select(func.max(PipelineCustomizedTemplate.position)).where(
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
)
)
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@@ -1239,13 +1239,14 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
db.session.scalar(
select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
)
.count()
or 0
) > 0
def get_node_last_run(
@@ -1353,11 +1354,11 @@ class RagPipelineService:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
if not pipeline_recommended_plugins:
return {
@@ -1396,14 +1397,12 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_execution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
document_pipeline_execution_log = db.session.scalar(
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
)
if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@@ -1432,23 +1431,23 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
@@ -1530,23 +1529,23 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")

View File

@@ -3,7 +3,7 @@ import logging
from datetime import datetime
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from core.tools.__base.tool_provider import ToolProviderController
@@ -42,20 +42,22 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@@ -122,30 +124,30 @@ class WorkflowToolManageService:
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
@@ -234,9 +236,11 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
db.session.commit()
@@ -251,10 +255,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -267,10 +271,10 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -284,8 +288,8 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = (
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
if workflow_app is None:
@@ -331,10 +335,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if db_tool is None:

View File

@@ -3,7 +3,7 @@ import logging
import mimetypes
import secrets
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, TypedDict
import orjson
from flask import request
@@ -38,7 +38,6 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@@ -51,6 +50,14 @@ logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class RawWebhookDataDict(TypedDict):
method: str
headers: dict[str, str]
query_params: dict[str, str]
body: dict[str, Any]
files: dict[str, Any]
class WebhookService:
"""Service for handling webhook operations."""
@@ -146,7 +153,7 @@ class WebhookService:
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
) -> RawWebhookDataDict:
"""Extract and validate webhook data in a single unified process.
Args:
@@ -174,7 +181,7 @@ class WebhookService:
return processed_data
@classmethod
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict:
"""Extract raw data from incoming webhook request without type conversion.
Args:
@@ -190,7 +197,7 @@ class WebhookService:
"""
cls._validate_content_length()
data = {
data: RawWebhookDataDict = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
@@ -224,7 +231,7 @@ class WebhookService:
return data
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict:
"""Process and validate webhook data according to node configuration.
Args:
@@ -665,7 +672,7 @@ class WebhookService:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> dict[str, Any]:
"""Validate HTTP method and content-type.
Args:
@@ -730,7 +737,7 @@ class WebhookService:
return False
@classmethod
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> dict[str, Any]:
"""Construct workflow inputs payload from webhook data.
Args:
@@ -748,7 +755,7 @@ class WebhookService:
@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.
@@ -783,9 +790,9 @@ class WebhookService:
user_id=None,
)
# reserve quota before triggering workflow execution
# consume quota before triggering workflow execution
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
@@ -796,16 +803,11 @@ class WebhookService:
raise
# Trigger workflow execution asynchronously
try:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@@ -28,7 +28,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from models.enums import (
AppTriggerType,
CreatorUserRole,
@@ -42,7 +42,6 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@@ -299,10 +298,10 @@ def dispatch_triggered_workflow(
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
# reserve quota before invoking trigger
# consume quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
@@ -388,7 +387,6 @@ def dispatch_triggered_workflow(
raise ValueError(f"End user not found for app {plugin_trigger.app_id}")
AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
quota_charge.commit()
dispatched_count += 1
logger.info(
"Triggered workflow for app %s with trigger event %s",

View File

@@ -8,11 +8,10 @@ from core.workflow.nodes.trigger_schedule.exc import (
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType
from enums.quota_type import QuotaType, unlimited
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@@ -44,7 +43,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
@@ -62,7 +61,6 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()

View File

@@ -36,19 +36,12 @@ class TestAppGenerateService:
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.quota_reserve.return_value = {
"reservation_id": "test-reservation-id",
"available": 100,
"reserved": 1,
}
mock_billing_service.quota_commit.return_value = {
"available": 99,
"reserved": 0,
"refunded": 0,
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
"result": "success",
"history_id": "test_history_id",
}
# Setup default mock returns for workflow service
@@ -108,8 +101,6 @@ class TestAppGenerateService:
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_quota_dify_config.BILLING_ENABLED = False
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@@ -127,7 +118,6 @@ class TestAppGenerateService:
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"quota_dify_config": mock_quota_dify_config,
"global_dify_config": mock_global_dify_config,
}
@@ -475,7 +465,6 @@ class TestAppGenerateService:
# Set BILLING_ENABLED to True for this test
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
# Setup test arguments
@@ -489,10 +478,8 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
# Verify billing two-phase quota (reserve + commit)
billing = mock_external_service_dependencies["billing_service"]
billing.quota_reserve.assert_called_once()
billing.quota_commit.assert_called_once()
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log(
)
# Mock quota to avoid rate limiting
from services import quota_service
from enums import quota_type
monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
# Execute schedule trigger
workflow_schedule_tasks.run_schedule_trigger(plan.id)

View File

@@ -20,7 +20,7 @@ def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(_load_user=lambda: None)
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return app

View File

@@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return flask_app

View File

@@ -1,349 +0,0 @@
"""Unit tests for QuotaType, QuotaService, and QuotaCharge."""
from unittest.mock import patch
import pytest
from enums.quota_type import QuotaType
from services.quota_service import QuotaCharge, QuotaService, unlimited
class TestQuotaType:
def test_billing_key_trigger(self):
assert QuotaType.TRIGGER.billing_key == "trigger_event"
def test_billing_key_workflow(self):
assert QuotaType.WORKFLOW.billing_key == "api_rate_limit"
def test_billing_key_unlimited_raises(self):
with pytest.raises(ValueError, match="Invalid quota type"):
_ = QuotaType.UNLIMITED.billing_key
class TestQuotaService:
def test_reserve_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService"),
):
mock_cfg.BILLING_ENABLED = False
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_reserve_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0)
def test_reserve_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99}
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1)
assert charge.success is True
assert charge.charge_id == "rid-1"
assert charge._tenant_id == "t1"
assert charge._feature_key == "trigger_event"
assert charge._amount == 1
mock_bs.quota_reserve.assert_called_once()
def test_reserve_no_reservation_id_raises(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {}
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_quota_exceeded_propagates(self):
from services.errors.app import QuotaExceededError
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1)
with pytest.raises(QuotaExceededError):
QuotaService.reserve(QuotaType.TRIGGER, "t1")
def test_reserve_api_exception_returns_unlimited(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.side_effect = RuntimeError("network")
charge = QuotaService.reserve(QuotaType.TRIGGER, "t1")
assert charge.success is True
assert charge.charge_id is None
def test_consume_calls_reserve_and_commit(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"}
mock_bs.quota_commit.return_value = {}
charge = QuotaService.consume(QuotaType.TRIGGER, "t1")
assert charge.success is True
mock_bs.quota_commit.assert_called_once()
def test_check_billing_disabled(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = False
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_check_zero_amount_raises(self):
with patch("services.quota_service.dify_config") as mock_cfg:
mock_cfg.BILLING_ENABLED = True
with pytest.raises(ValueError, match="greater than 0"):
QuotaService.check(QuotaType.TRIGGER, "t1", amount=0)
def test_check_sufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=100),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True
def test_check_insufficient_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=5),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False
def test_check_unlimited_quota(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", return_value=-1),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True
def test_check_exception_returns_true(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch.object(QuotaService, "get_remaining", side_effect=RuntimeError),
):
mock_cfg.BILLING_ENABLED = True
assert QuotaService.check(QuotaType.TRIGGER, "t1") is True
def test_release_billing_disabled(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = False
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_empty_reservation(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event")
mock_bs.quota_release.assert_not_called()
def test_release_success(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.return_value = {}
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
mock_bs.quota_release.assert_called_once_with(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1"
)
def test_release_exception_swallowed(self):
with (
patch("services.quota_service.dify_config") as mock_cfg,
patch("services.billing_service.BillingService") as mock_bs,
):
mock_cfg.BILLING_ENABLED = True
mock_bs.quota_release.side_effect = RuntimeError("fail")
QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_get_remaining_normal(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70
def test_get_remaining_unlimited(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_over_limit_returns_zero(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_exception_returns_neg1(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.side_effect = RuntimeError
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1
def test_get_remaining_empty_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_non_dict_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = "invalid"
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
def test_get_remaining_feature_not_in_response(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}}
remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1")
assert remaining == 0
def test_get_remaining_non_dict_feature_info(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"}
assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0
class TestQuotaCharge:
def test_commit_success(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
mock_bs.quota_commit.assert_called_once_with(
tenant_id="t1",
feature_key="trigger_event",
reservation_id="rid-1",
actual_amount=1,
)
assert charge._committed is True
def test_commit_with_actual_amount(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=10,
)
charge.commit(actual_amount=5)
call_kwargs = mock_bs.quota_commit.call_args[1]
assert call_kwargs["actual_amount"] == 5
def test_commit_idempotent(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.return_value = {}
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
charge.commit()
assert mock_bs.quota_commit.call_count == 1
def test_commit_no_charge_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_no_tenant_id_noop(self):
with patch("services.billing_service.BillingService") as mock_bs:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
_feature_key="trigger_event",
)
charge.commit()
mock_bs.quota_commit.assert_not_called()
def test_commit_exception_swallowed(self):
with patch("services.billing_service.BillingService") as mock_bs:
mock_bs.quota_commit.side_effect = RuntimeError("fail")
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
_amount=1,
)
charge.commit()
def test_refund_success(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id="t1",
_feature_key="trigger_event",
)
charge.refund()
mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event")
def test_refund_no_charge_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER)
charge.refund()
mock_rel.assert_not_called()
def test_refund_no_tenant_id_noop(self):
with patch.object(QuotaService, "release") as mock_rel:
charge = QuotaCharge(
success=True,
charge_id="rid-1",
_quota_type=QuotaType.TRIGGER,
_tenant_id=None,
)
charge.refund()
mock_rel.assert_not_called()
class TestUnlimited:
def test_unlimited_returns_success_with_no_charge_id(self):
charge = unlimited()
assert charge.success is True
assert charge.charge_id is None
assert charge._quota_type == QuotaType.UNLIMITED

View File

@@ -0,0 +1,17 @@
import json
from flask import Response
from extensions.ext_login import unauthorized_handler
def test_unauthorized_handler_returns_json_response() -> None:
response = unauthorized_handler()
assert isinstance(response, Response)
assert response.status_code == 401
assert response.content_type == "application/json"
assert json.loads(response.get_data(as_text=True)) == {
"code": "unauthorized",
"message": "Unauthorized.",
}

View File

@@ -2,11 +2,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask, g
from flask_login import LoginManager, UserMixin
from flask import Flask, Response, g
from flask_login import UserMixin
from pytest_mock import MockerFixture
import libs.login as login_module
from extensions.ext_login import DifyLoginManager
from libs.login import current_user
from models.account import Account
@@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
login_manager = LoginManager()
login_manager = DifyLoginManager()
login_manager.init_app(app)
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
login_manager.unauthorized = mocker.Mock(
name="unauthorized",
return_value=Response("Unauthorized", status=401, content_type="application/json"),
)
@login_manager.user_loader
def load_user(_user_id: str):
@@ -109,18 +113,43 @@ class TestLoginRequired:
resolved_user: MockUser | None,
description: str,
):
"""Test that missing or unauthenticated users are redirected."""
"""Test that missing or unauthenticated users return the manager response."""
resolve_user = resolve_current_user(resolved_user)
with login_app.test_request_context():
result = protected_view()
assert result == "Unauthorized", description
assert result is login_app.login_manager.unauthorized.return_value, description
assert isinstance(result, Response)
assert result.status_code == 401
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
def test_unauthorized_access_propagates_response_object(
self,
login_app: Flask,
protected_view,
csrf_check: MagicMock,
resolve_current_user,
mocker: MockerFixture,
) -> None:
"""Test that unauthorized responses are propagated as Flask Response objects."""
resolve_user = resolve_current_user(None)
response = Response("Unauthorized", status=401, content_type="application/json")
mocker.patch.object(
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
)
with login_app.test_request_context():
result = protected_view()
assert result is response
assert isinstance(result, Response)
resolve_user.assert_called_once_with()
csrf_check.assert_not_called()
@pytest.mark.parametrize(
("method", "login_disabled"),
[
@@ -168,10 +197,14 @@ class TestGetUser:
"""Test that _get_user loads user if not already in g."""
mock_user = MockUser("test_user")
def _load_user() -> None:
def load_user_from_request_context() -> None:
g._login_user = mock_user
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
load_user = mocker.patch.object(
login_app.login_manager,
"load_user_from_request_context",
side_effect=load_user_from_request_context,
)
with login_app.test_request_context():
user = login_module._get_user()

View File

@@ -401,10 +401,7 @@ class TestMetadataServiceCreateMetadata:
metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
# Mock query to return None (no existing metadata with same name)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Mock BuiltInField enum iteration
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@@ -417,10 +414,6 @@ class TestMetadataServiceCreateMetadata:
assert result is not None
assert isinstance(result, DatasetMetadata)
# Verify query was made to check for duplicates
mock_db_session.query.assert_called()
mock_query.filter_by.assert_called()
# Verify metadata was added and committed
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
@@ -468,10 +461,7 @@ class TestMetadataServiceCreateMetadata:
# Mock existing metadata with same name
existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category")
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = existing_metadata
# Act & Assert
with pytest.raises(ValueError, match="Metadata name already exists"):
@@ -500,10 +490,7 @@ class TestMetadataServiceCreateMetadata:
)
# Mock query to return None (no duplicate in database)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Mock BuiltInField to include the conflicting name
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@@ -597,27 +584,11 @@ class TestMetadataServiceUpdateMetadataName:
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
# Mock query for duplicate check (no duplicate)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
# Mock metadata retrieval
def query_side_effect(model):
if model == DatasetMetadata:
mock_meta_query = Mock()
mock_meta_query.filter_by.return_value = mock_meta_query
mock_meta_query.first.return_value = existing_metadata
return mock_meta_query
return mock_query
mock_db_session.query.side_effect = query_side_effect
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval
mock_db_session.scalar.side_effect = [None, existing_metadata]
# Mock no metadata bindings (no documents to update)
mock_binding_query = Mock()
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
# Mock BuiltInField enum
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@@ -655,22 +626,8 @@ class TestMetadataServiceUpdateMetadataName:
metadata_id = "non-existent-metadata"
new_name = "updated_category"
# Mock query for duplicate check (no duplicate)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
# Mock metadata retrieval to return None
def query_side_effect(model):
if model == DatasetMetadata:
mock_meta_query = Mock()
mock_meta_query.filter_by.return_value = mock_meta_query
mock_meta_query.first.return_value = None # Not found
return mock_meta_query
return mock_query
mock_db_session.query.side_effect = query_side_effect
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval (None = not found)
mock_db_session.scalar.side_effect = [None, None]
# Mock BuiltInField enum
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@@ -746,15 +703,10 @@ class TestMetadataServiceDeleteMetadata:
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
# Mock metadata retrieval
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = existing_metadata
# Mock no metadata bindings (no documents to update)
mock_binding_query = Mock()
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
# Act
result = MetadataService.delete_metadata(dataset_id, metadata_id)
@@ -788,10 +740,7 @@ class TestMetadataServiceDeleteMetadata:
metadata_id = "non-existent-metadata"
# Mock metadata retrieval to return None
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="Metadata not found"):
@@ -1013,10 +962,7 @@ class TestMetadataServiceGetDatasetMetadatas:
)
# Mock usage count queries
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 5 # 5 documents use this metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = 5 # 5 documents use this metadata
# Act
result = MetadataService.get_dataset_metadatas(dataset)

View File

@@ -292,7 +292,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
assert result is api
@@ -302,7 +302,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
When the record is absent, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
@@ -320,7 +320,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
existing_api = Mock(spec=ExternalKnowledgeApis)
existing_api.settings_dict = {"api_key": "stored-key"}
existing_api.settings = '{"api_key":"stored-key"}'
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
mock_db_session.scalar.return_value = existing_api
args = {
"name": "New Name",
@@ -340,7 +340,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Updating a nonexistent API template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.update_external_knowledge_api(
@@ -356,7 +356,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
@@ -368,7 +368,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Deletion of a missing template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
@@ -394,7 +394,7 @@ class TestExternalDatasetServiceUsageAndBindings:
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@@ -406,7 +406,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Zero bindings should return ``(False, 0)``.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@@ -419,7 +419,7 @@ class TestExternalDatasetServiceUsageAndBindings:
"""
binding = Mock(spec=ExternalKnowledgeBindings)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
mock_db_session.scalar.return_value = binding
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
assert result is binding
@@ -429,7 +429,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Missing binding should result in a ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
@@ -460,7 +460,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
# Raw string; the service itself calls json.loads on it
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"foo": "value", "bar": "optional"}
@@ -474,7 +474,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
When the referenced API template is missing, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
@@ -488,7 +488,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
external_api.settings = (
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"bar": "present"} # missing "foo"
@@ -702,7 +702,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
}
# No existing dataset with same name.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None, # duplicatename check
Mock(spec=ExternalKnowledgeApis), # external knowledge api
]
@@ -724,7 +724,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
existing_dataset = Mock(spec=Dataset)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
mock_db_session.scalar.return_value = existing_dataset
args = {
"name": "Existing",
@@ -744,7 +744,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
# First call: duplicate name check not found.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None,
None, # external knowledge api lookup
]
@@ -763,8 +763,10 @@ class TestExternalDatasetServiceCreateExternalDataset:
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
"""
# duplicate name check
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
mock_db_session.scalar.side_effect = [
None,
Mock(spec=ExternalKnowledgeApis),
None,
Mock(spec=ExternalKnowledgeApis),
]
@@ -826,7 +828,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
# First query: binding; second query: api.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]
@@ -861,7 +863,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
Missing binding should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
@@ -878,7 +880,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
"""
binding = ExternalDatasetTestDataFactory.create_external_binding()
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
None,
]
@@ -901,7 +903,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api = Mock(spec=ExternalKnowledgeApis)
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]

View File

@@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
first_query = mocker.Mock()
first_query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
@@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
def test_update_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
# First query finds the template, second query (duplicate check) returns None
query_mock_1 = mocker.Mock()
query_mock_1.where.return_value.first.return_value = template
query_mock_2 = mocker.Mock()
query_mock_2.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
# First scalar finds the template, second scalar (duplicate check) returns None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
def test_update_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
@@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
duplicate = SimpleNamespace(name="dup")
query_mock = mocker.Mock()
query_mock.where.return_value.first.side_effect = [template, duplicate]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
@@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
def test_delete_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(id="tpl-1")
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
@@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 1
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 0
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
@@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.workflow import Workflow
from models.dataset import Pipeline
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
@@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
# Mock db itself to avoid app context errors
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
# Improved mocking for session.query
def mock_query_side_effect(model):
m = mocker.Mock()
if model == Pipeline:
m.where.return_value.first.return_value = pipeline
elif model == Workflow:
m.where.return_value.first.return_value = workflow
elif model == PipelineCustomizedTemplate:
m.where.return_value.first.return_value = None
elif model == Dataset:
m.where.return_value.first.return_value = mocker.Mock()
else:
# For func.max cases
m.where.return_value.scalar.return_value = 5
m.where.return_value.first.return_value = mocker.Mock()
return m
mock_db.session.query.side_effect = mock_query_side_effect
# Mock get() for Pipeline and Workflow PK lookups
mock_db.session.get.side_effect = [pipeline, workflow]
# Mock scalar() for template name check (None) and max position (5)
mock_db.session.scalar.side_effect = [None, 5]
# Mock retrieve_dataset
dataset = mocker.Mock()
pipeline.retrieve_dataset.return_value = dataset
# Mock max position
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
mocker.patch(
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
return_value=5,
)
# Mock RagPipelineDslService
mock_dsl_service = mocker.Mock()
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
@@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
workflow.rag_pipeline_variables = []
# Mock queries
mock_query = mocker.Mock()
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock()
# Mock queries
mock_query = mocker.Mock()
# Log lookup, then Pipeline lookup
mock_query.where.return_value.first.side_effect = [log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
# Mock queries: Log lookup via scalar, Pipeline lookup via get
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
# Mock db aggressively
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock()
mock_db.session.scalar.return_value = mocker.Mock()
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p-1"
@@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_draft_workflow(pipeline)
@@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_published_workflow(pipeline)
@@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = []
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = []
result = rag_pipeline_service.get_recommended_plugins("all")
@@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
plugin_a = SimpleNamespace(plugin_id="plugin-a")
plugin_b = SimpleNamespace(plugin_id="plugin-b")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch(
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
@@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Document pipeline execution log not found"):
rag_pipeline_service.retry_error_document(
@@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Pipeline or workflow not found"):
@@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock()
query.where.return_value.first.return_value = document
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
query = mocker.Mock()
query.where.return_value.first.return_value = pipeline
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
with pytest.raises(ValueError, match="Pipeline workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
@@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock()
query.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
workflow = SimpleNamespace(id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, workflow]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value = query
mock_db.session.get.side_effect = [pipeline, workflow]
session_ctx = mocker.MagicMock()
session_ctx.__enter__.return_value = SimpleNamespace()
session_ctx.__exit__.return_value = False
@@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
plugin = SimpleNamespace(plugin_id="plugin-a")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
@@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.retry_error_document(
@@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Workflow not found"):
@@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
@@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
@@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
@@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

View File

@@ -173,9 +173,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@@ -188,9 +186,7 @@ class TestAccountService:
def test_authenticate_account_not_found(self, mock_db_dependencies):
"""Test authentication when account does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "email", "notfound@example.com"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test and verify exception
self._assert_exception_raised(
@@ -202,9 +198,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "email", "banned@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
@@ -214,9 +208,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = False
@@ -230,9 +222,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
# Setup smart database query mock
query_results = {("Account", "email", "pending@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@@ -422,12 +412,8 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Setup smart database query mock
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@@ -444,9 +430,7 @@ class TestAccountService:
def test_load_user_not_found(self, mock_db_dependencies):
"""Test user loading when user does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "id", "non-existent-user"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = None
# Execute test
result = AccountService.load_user("non-existent-user")
@@ -459,9 +443,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "id", "user-123"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(
@@ -476,13 +458,9 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
# Setup smart database query mock for complex scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@@ -503,13 +481,9 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock for no tenants scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant (None)
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@@ -1060,7 +1034,7 @@ class TestRegisterService:
)
# Verify rollback operations were called
mock_db_dependencies["db"].session.query.assert_called()
mock_db_dependencies["db"].session.execute.assert_called()
# ==================== Registration Tests ====================
@@ -1625,10 +1599,8 @@ class TestRegisterService:
mock_session_class.return_value.__exit__.return_value = None
mock_lookup.return_value = mock_existing_account
# Mock the db.session.query for TenantAccountJoin
mock_db_query = MagicMock()
mock_db_query.filter_by.return_value.first.return_value = None # No existing member
mock_db_dependencies["db"].session.query.return_value = mock_db_query
# Mock scalar for TenantAccountJoin lookup - no existing member
mock_db_dependencies["db"].session.scalar.return_value = None
# Mock TenantService methods
with (
@@ -1803,14 +1775,9 @@ class TestRegisterService:
}
mock_get_invitation_by_token.return_value = invitation_data
# Mock database queries - complex query mocking
mock_query1 = MagicMock()
mock_query1.where.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant lookup, execute for account+role lookup
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@@ -1842,10 +1809,8 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries - no tenant found
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = None
mock_db_dependencies["db"].session.query.return_value = mock_query
# Mock scalar for tenant lookup - not found
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@@ -1868,14 +1833,9 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant, execute for account+role
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = None # No account found
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
@@ -1901,14 +1861,9 @@ class TestRegisterService:
}
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
# Mock database queries
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
# Mock scalar for tenant, execute for account+role
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant
mock_db_dependencies["db"].session.execute.return_value.first.return_value = (mock_account, "normal")
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")

View File

@@ -23,7 +23,6 @@ import pytest
import services.app_generate_service as ags_module
from core.app.entities.app_invoke_entities import InvokeFrom
from enums.quota_type import QuotaType
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
@@ -448,8 +447,8 @@ class TestGenerateBilling:
def test_billing_enabled_consumes_quota(self, mocker, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
reserve_mock = mocker.patch(
"services.app_generate_service.QuotaService.reserve",
consume_mock = mocker.patch(
"services.app_generate_service.QuotaType.WORKFLOW.consume",
return_value=quota_charge,
)
mocker.patch(
@@ -468,8 +467,7 @@ class TestGenerateBilling:
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
reserve_mock.assert_called_once_with(QuotaType.WORKFLOW, "tenant-id")
quota_charge.commit.assert_called_once()
consume_mock.assert_called_once_with("tenant-id")
def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch):
from services.errors.app import QuotaExceededError
@@ -477,7 +475,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
mocker.patch(
"services.app_generate_service.QuotaService.reserve",
"services.app_generate_service.QuotaType.WORKFLOW.consume",
side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1),
)
@@ -494,7 +492,7 @@ class TestGenerateBilling:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True)
quota_charge = MagicMock()
mocker.patch(
"services.app_generate_service.QuotaService.reserve",
"services.app_generate_service.QuotaType.WORKFLOW.consume",
return_value=quota_charge,
)
mocker.patch(

View File

@@ -57,7 +57,7 @@ class TestAsyncWorkflowService:
- repo: SQLAlchemyWorkflowTriggerLogRepository
- dispatcher_manager_class: QueueDispatcherManager class
- dispatcher: dispatcher instance
- quota_service: QuotaService mock
- quota_workflow: QuotaType.WORKFLOW
- get_workflow: AsyncWorkflowService._get_workflow method
- professional_task: execute_workflow_professional
- team_task: execute_workflow_team
@@ -72,7 +72,7 @@ class TestAsyncWorkflowService:
mock_repo.create.side_effect = _create_side_effect
mock_dispatcher = MagicMock()
mock_quota_service = MagicMock()
quota_workflow = MagicMock()
mock_get_workflow = MagicMock()
mock_professional_task = MagicMock()
@@ -93,8 +93,8 @@ class TestAsyncWorkflowService:
) as mock_get_workflow,
patch.object(
async_workflow_service_module,
"QuotaService",
new=mock_quota_service,
"QuotaType",
new=SimpleNamespace(WORKFLOW=quota_workflow),
),
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
@@ -107,7 +107,7 @@ class TestAsyncWorkflowService:
"repo": mock_repo,
"dispatcher_manager_class": mock_dispatcher_manager_class,
"dispatcher": mock_dispatcher,
"quota_service": mock_quota_service,
"quota_workflow": quota_workflow,
"get_workflow": mock_get_workflow,
"professional_task": mock_professional_task,
"team_task": mock_team_task,
@@ -146,9 +146,6 @@ class TestAsyncWorkflowService:
mocks["team_task"].delay.return_value = task_result
mocks["sandbox_task"].delay.return_value = task_result
quota_charge_mock = MagicMock()
mocks["quota_service"].reserve.return_value = quota_charge_mock
class DummyAccount:
def __init__(self, user_id: str):
self.id = user_id
@@ -166,8 +163,7 @@ class TestAsyncWorkflowService:
assert result.status == "queued"
assert result.queue == queue_name
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
assert session.commit.call_count == 2
created_log = mocks["repo"].create.call_args[0][0]
@@ -254,7 +250,7 @@ class TestAsyncWorkflowService:
mocks = async_workflow_trigger_mocks
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
mocks["get_workflow"].return_value = workflow
mocks["quota_service"].reserve.side_effect = QuotaExceededError(
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
feature="workflow",
tenant_id="tenant-123",
required=1,

View File

@@ -290,19 +290,9 @@ class TestBillingServiceSubscriptionInfo:
# Arrange
tenant_id = "tenant-123"
expected_response = {
"enabled": True,
"subscription": {"plan": "professional", "interval": "month", "education": False},
"members": {"size": 1, "limit": 50},
"apps": {"size": 1, "limit": 200},
"vector_space": {"size": 0.0, "limit": 20480},
"knowledge_rate_limit": {"limit": 1000},
"documents_upload_quota": {"size": 0, "limit": 1000},
"annotation_quota_limit": {"size": 0, "limit": 5000},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": 1775952000,
"subscription_plan": "professional",
"billing_cycle": "monthly",
"status": "active",
}
mock_send_request.return_value = expected_response
@@ -425,7 +415,7 @@ class TestBillingServiceUsageCalculation:
yield mock
def test_get_tenant_feature_plan_usage_info(self, mock_send_request):
"""Test retrieval of tenant feature plan usage information (legacy endpoint)."""
"""Test retrieval of tenant feature plan usage information."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}}
@@ -438,20 +428,6 @@ class TestBillingServiceUsageCalculation:
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id})
def test_get_quota_info(self, mock_send_request):
"""Test retrieval of quota info from new endpoint."""
# Arrange
tenant_id = "tenant-123"
expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}}
mock_send_request.return_value = expected_response
# Act
result = BillingService.get_quota_info(tenant_id)
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id})
def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request):
"""Test updating tenant feature usage with positive delta (adding credits)."""
# Arrange
@@ -529,118 +505,6 @@ class TestBillingServiceUsageCalculation:
)
class TestBillingServiceQuotaOperations:
"""Unit tests for quota reserve/commit/release operations."""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_quota_reserve_success(self, mock_send_request):
expected = {"reservation_id": "rid-1", "available": 99, "reserved": 1}
mock_send_request.return_value = expected
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-1", amount=1)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/reserve",
json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1},
)
def test_quota_reserve_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"}
result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1)
assert result["available"] == 99
assert isinstance(result["available"], int)
assert result["reserved"] == 1
assert isinstance(result["reserved"], int)
def test_quota_reserve_with_meta(self, mock_send_request):
mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1}
meta = {"source": "webhook"}
BillingService.quota_reserve(
tenant_id="t1", feature_key="trigger_event", request_id="req-2", amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"source": "webhook"}
def test_quota_commit_success(self, mock_send_request):
expected = {"available": 98, "reserved": 0, "refunded": 0}
mock_send_request.return_value = expected
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1
)
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/commit",
json={
"tenant_id": "t1",
"feature_key": "trigger_event",
"reservation_id": "rid-1",
"actual_amount": 1,
},
)
def test_quota_commit_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"}
result = BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1
)
assert result["available"] == 97
assert isinstance(result["available"], int)
assert result["refunded"] == 1
assert isinstance(result["refunded"], int)
def test_quota_commit_with_meta(self, mock_send_request):
mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0}
meta = {"reason": "partial"}
BillingService.quota_commit(
tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1", actual_amount=1, meta=meta
)
call_json = mock_send_request.call_args[1]["json"]
assert call_json["meta"] == {"reason": "partial"}
def test_quota_release_success(self, mock_send_request):
expected = {"available": 100, "reserved": 0, "released": 1}
mock_send_request.return_value = expected
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1")
assert result == expected
mock_send_request.assert_called_once_with(
"POST",
"/quota/release",
json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"},
)
def test_quota_release_coerces_string_to_int(self, mock_send_request):
"""Test that TypeAdapter coerces string values to int."""
mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"}
result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s")
assert result["available"] == 100
assert isinstance(result["available"], int)
assert result["released"] == 1
assert isinstance(result["released"], int)
class TestBillingServiceRateLimitEnforcement:
"""Unit tests for rate limit enforcement mechanisms.
@@ -1131,14 +995,17 @@ class TestBillingServiceEdgeCases:
yield mock
def test_get_info_empty_response(self, mock_send_request):
"""Empty response from billing API should raise ValidationError due to missing required fields."""
from pydantic import ValidationError
"""Test handling of empty billing info response."""
# Arrange
tenant_id = "tenant-empty"
mock_send_request.return_value = {}
with pytest.raises(ValidationError):
BillingService.get_info(tenant_id)
# Act
result = BillingService.get_info(tenant_id)
# Assert
assert result == {}
mock_send_request.assert_called_once()
def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request):
"""Test updating tenant feature usage with zero delta (no change)."""
@@ -1549,21 +1416,12 @@ class TestBillingServiceIntegrationScenarios:
# Step 1: Get current billing info
mock_send_request.return_value = {
"enabled": True,
"subscription": {"plan": "sandbox", "interval": "", "education": False},
"members": {"size": 0, "limit": 1},
"apps": {"size": 0, "limit": 5},
"vector_space": {"size": 0.0, "limit": 50},
"knowledge_rate_limit": {"limit": 10},
"documents_upload_quota": {"size": 0, "limit": 50},
"annotation_quota_limit": {"size": 0, "limit": 10},
"docs_processing": "standard",
"can_replace_logo": False,
"model_load_balancing_enabled": False,
"knowledge_pipeline_publish_enabled": False,
"subscription_plan": "sandbox",
"billing_cycle": "monthly",
"status": "active",
}
current_info = BillingService.get_info(tenant_id)
assert current_info["subscription"]["plan"] == "sandbox"
assert current_info["subscription_plan"] == "sandbox"
# Step 2: Get payment link for upgrade
mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"}
@@ -1677,140 +1535,3 @@ class TestBillingServiceIntegrationScenarios:
mock_send_request.return_value = {"result": "success", "activated": True}
activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student")
assert activate_result["activated"] is True
class TestBillingServiceSubscriptionInfoDataType:
"""Unit tests for data type coercion in BillingService.get_info
1. Verifies the get_info returns correct Python types for numeric fields
2. Ensure the compatibility regardless of what results the upstream billing API returns
"""
@pytest.fixture
def mock_send_request(self):
with patch.object(BillingService, "_send_request") as mock:
yield mock
@pytest.fixture
def normal_billing_response(self) -> dict:
return {
"enabled": True,
"subscription": {
"plan": "team",
"interval": "year",
"education": False,
},
"members": {"size": 10, "limit": 50},
"apps": {"size": 80, "limit": 200},
"vector_space": {"size": 5120.75, "limit": 20480},
"knowledge_rate_limit": {"limit": 1000},
"documents_upload_quota": {"size": 450, "limit": 1000},
"annotation_quota_limit": {"size": 1200, "limit": 5000},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": 1745971200,
}
@pytest.fixture
def string_billing_response(self) -> dict:
return {
"enabled": True,
"subscription": {
"plan": "team",
"interval": "year",
"education": False,
},
"members": {"size": "10", "limit": "50"},
"apps": {"size": "80", "limit": "200"},
"vector_space": {"size": "5120.75", "limit": "20480"},
"knowledge_rate_limit": {"limit": "1000"},
"documents_upload_quota": {"size": "450", "limit": "1000"},
"annotation_quota_limit": {"size": "1200", "limit": "5000"},
"docs_processing": "top-priority",
"can_replace_logo": True,
"model_load_balancing_enabled": True,
"knowledge_pipeline_publish_enabled": True,
"next_credit_reset_date": "1745971200",
}
@staticmethod
def _assert_billing_info_types(result: dict):
assert isinstance(result["enabled"], bool)
assert isinstance(result["subscription"]["plan"], str)
assert isinstance(result["subscription"]["interval"], str)
assert isinstance(result["subscription"]["education"], bool)
assert isinstance(result["members"]["size"], int)
assert isinstance(result["members"]["limit"], int)
assert isinstance(result["apps"]["size"], int)
assert isinstance(result["apps"]["limit"], int)
assert isinstance(result["vector_space"]["size"], float)
assert isinstance(result["vector_space"]["limit"], int)
assert isinstance(result["knowledge_rate_limit"]["limit"], int)
assert isinstance(result["documents_upload_quota"]["size"], int)
assert isinstance(result["documents_upload_quota"]["limit"], int)
assert isinstance(result["annotation_quota_limit"]["size"], int)
assert isinstance(result["annotation_quota_limit"]["limit"], int)
assert isinstance(result["docs_processing"], str)
assert isinstance(result["can_replace_logo"], bool)
assert isinstance(result["model_load_balancing_enabled"], bool)
assert isinstance(result["knowledge_pipeline_publish_enabled"], bool)
if "next_credit_reset_date" in result:
assert isinstance(result["next_credit_reset_date"], int)
def test_get_info_with_normal_types(self, mock_send_request, normal_billing_response):
"""When the billing API returns native numeric types, get_info should preserve them."""
mock_send_request.return_value = normal_billing_response
result = BillingService.get_info("tenant-type-test")
self._assert_billing_info_types(result)
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
def test_get_info_with_string_types(self, mock_send_request, string_billing_response):
"""When the billing API returns numeric values as strings, get_info should coerce them."""
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
self._assert_billing_info_types(result)
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": "tenant-type-test"})
def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response):
"""NotRequired fields can be absent without raising."""
del string_billing_response["next_credit_reset_date"]
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
assert "next_credit_reset_date" not in result
self._assert_billing_info_types(result)
def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response):
"""Undefined fields are silently stripped by validate_python."""
string_billing_response["new_feature"] = "something"
mock_send_request.return_value = string_billing_response
result = BillingService.get_info("tenant-type-test")
# extra fields are dropped by TypeAdapter on TypedDict
assert "new_feature" not in result
self._assert_billing_info_types(result)
def test_get_info_missing_required_field_raises(self, mock_send_request, string_billing_response):
"""Missing a required field should raise ValidationError."""
from pydantic import ValidationError
del string_billing_response["members"]
mock_send_request.return_value = string_billing_response
with pytest.raises(ValidationError):
BillingService.get_info("tenant-type-test")

View File

@@ -799,10 +799,7 @@ class TestExternalDatasetServiceGetAPI:
api_id = "api-123"
expected_api = factory.create_external_knowledge_api_mock(api_id=api_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_api
mock_db.session.scalar.return_value = expected_api
# Act
tenant_id = "tenant-123"
@@ -810,16 +807,12 @@ class TestExternalDatasetServiceGetAPI:
# Assert
assert result.id == api_id
mock_query.filter_by.assert_called_once_with(id=api_id, tenant_id=tenant_id)
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -848,10 +841,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://new.example.com", "api_key": "new-key"},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
@@ -881,10 +871,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, "user-123", api_id, args)
@@ -897,10 +884,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@@ -912,10 +896,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@@ -934,10 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
args = {"name": "New Name Only"}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
@@ -958,10 +936,7 @@ class TestExternalDatasetServiceDeleteAPI:
existing_api = factory.create_external_knowledge_api_mock(api_id=api_id, tenant_id=tenant_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
ExternalDatasetService.delete_external_knowledge_api(tenant_id, api_id)
@@ -974,10 +949,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -987,10 +959,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -1006,10 +975,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 1
mock_db.session.scalar.return_value = 1
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@@ -1024,10 +990,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 10
mock_db.session.scalar.return_value = 10
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@@ -1042,10 +1005,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 0
mock_db.session.scalar.return_value = 0
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@@ -1067,10 +1027,7 @@ class TestExternalDatasetServiceGetBinding:
expected_binding = factory.create_external_knowledge_binding_mock(tenant_id=tenant_id, dataset_id=dataset_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_binding
mock_db.session.scalar.return_value = expected_binding
# Act
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id(tenant_id, dataset_id)
@@ -1083,10 +1040,7 @@ class TestExternalDatasetServiceGetBinding:
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
"""Test error when binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@@ -1113,10 +1067,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"param1": "value1", "param2": "value2"}
@@ -1134,10 +1085,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {}
@@ -1149,10 +1097,7 @@ class TestExternalDatasetServiceDocumentValidate:
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
"""Test validation fails when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@@ -1165,10 +1110,7 @@ class TestExternalDatasetServiceDocumentValidate:
settings = {}
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
# Act & Assert - should not raise
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
@@ -1186,10 +1128,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"required_param": "value"}
@@ -1498,24 +1437,7 @@ class TestExternalDatasetServiceCreateDataset:
api = factory.create_external_knowledge_api_mock(api_id="api-123")
# Mock database queries
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
# Act
result = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
@@ -1534,10 +1456,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_dataset
mock_db.session.scalar.return_value = existing_dataset
args = {"name": "Duplicate Dataset"}
@@ -1549,23 +1468,7 @@ class TestExternalDatasetServiceCreateDataset:
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
"""Test error when external knowledge API is not found."""
# Arrange
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = None
mock_db.session.scalar.side_effect = [None, None]
args = {"name": "Test Dataset", "external_knowledge_api_id": "nonexistent-api"}
@@ -1579,23 +1482,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_api_id": "api-123"}
@@ -1609,23 +1496,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_id": "knowledge-123"}
@@ -1651,23 +1522,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1695,10 +1550,7 @@ class TestExternalDatasetServiceFetchRetrieval:
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
"""Test error when external knowledge binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@@ -1712,23 +1564,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1751,23 +1587,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@@ -1799,23 +1619,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 500
@@ -1856,23 +1660,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = status_code
@@ -1891,23 +1679,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 503

40
api/uv.lock generated
View File

@@ -53,23 +53,6 @@ dependencies = [
]
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/7e/cb94129302d78c46662b47f9897d642fd0b33bdfef4b73b20c6ced35aa4c/aiohttp-3.13.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8ea0c64d1bcbf201b285c2246c51a0c035ba3bbd306640007bc5844a3b4658c1", size = 760027, upload-time = "2026-03-28T17:15:33.022Z" },
{ url = "https://files.pythonhosted.org/packages/5e/cd/2db3c9397c3bd24216b203dd739945b04f8b87bb036c640da7ddb63c75ef/aiohttp-3.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6f742e1fa45c0ed522b00ede565e18f97e4cf8d1883a712ac42d0339dfb0cce7", size = 508325, upload-time = "2026-03-28T17:15:34.714Z" },
{ url = "https://files.pythonhosted.org/packages/36/a3/d28b2722ec13107f2e37a86b8a169897308bab6a3b9e071ecead9d67bd9b/aiohttp-3.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dcfb50ee25b3b7a1222a9123be1f9f89e56e67636b561441f0b304e25aaef8f", size = 502402, upload-time = "2026-03-28T17:15:36.409Z" },
{ url = "https://files.pythonhosted.org/packages/fa/d6/acd47b5f17c4430e555590990a4746efbcb2079909bb865516892bf85f37/aiohttp-3.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3262386c4ff370849863ea93b9ea60fd59c6cf56bf8f93beac625cf4d677c04d", size = 1771224, upload-time = "2026-03-28T17:15:38.223Z" },
{ url = "https://files.pythonhosted.org/packages/98/af/af6e20113ba6a48fd1cd9e5832c4851e7613ef50c7619acdaee6ec5f1aff/aiohttp-3.13.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:473bb5aa4218dd254e9ae4834f20e31f5a0083064ac0136a01a62ddbae2eaa42", size = 1731530, upload-time = "2026-03-28T17:15:39.988Z" },
{ url = "https://files.pythonhosted.org/packages/81/16/78a2f5d9c124ad05d5ce59a9af94214b6466c3491a25fb70760e98e9f762/aiohttp-3.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e56423766399b4c77b965f6aaab6c9546617b8994a956821cc507d00b91d978c", size = 1827925, upload-time = "2026-03-28T17:15:41.944Z" },
{ url = "https://files.pythonhosted.org/packages/2a/1f/79acf0974ced805e0e70027389fccbb7d728e6f30fcac725fb1071e63075/aiohttp-3.13.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8af249343fafd5ad90366a16d230fc265cf1149f26075dc9fe93cfd7c7173942", size = 1923579, upload-time = "2026-03-28T17:15:44.071Z" },
{ url = "https://files.pythonhosted.org/packages/af/53/29f9e2054ea6900413f3b4c3eb9d8331f60678ec855f13ba8714c47fd48d/aiohttp-3.13.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc0a5cf4f10ef5a2c94fdde488734b582a3a7a000b131263e27c9295bd682d9", size = 1767655, upload-time = "2026-03-28T17:15:45.911Z" },
{ url = "https://files.pythonhosted.org/packages/f3/57/462fe1d3da08109ba4aa8590e7aed57c059af2a7e80ec21f4bac5cfe1094/aiohttp-3.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5c7ff1028e3c9fc5123a865ce17df1cb6424d180c503b8517afbe89aa566e6be", size = 1630439, upload-time = "2026-03-28T17:15:48.11Z" },
{ url = "https://files.pythonhosted.org/packages/d7/4b/4813344aacdb8127263e3eec343d24e973421143826364fa9fc847f6283f/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ba5cf98b5dcb9bddd857da6713a503fa6d341043258ca823f0f5ab7ab4a94ee8", size = 1745557, upload-time = "2026-03-28T17:15:50.13Z" },
{ url = "https://files.pythonhosted.org/packages/d4/01/1ef1adae1454341ec50a789f03cfafe4c4ac9c003f6a64515ecd32fe4210/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d85965d3ba21ee4999e83e992fecb86c4614d6920e40705501c0a1f80a583c12", size = 1741796, upload-time = "2026-03-28T17:15:52.351Z" },
{ url = "https://files.pythonhosted.org/packages/22/04/8cdd99af988d2aa6922714d957d21383c559835cbd43fbf5a47ddf2e0f05/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:49f0b18a9b05d79f6f37ddd567695943fcefb834ef480f17a4211987302b2dc7", size = 1805312, upload-time = "2026-03-28T17:15:54.407Z" },
{ url = "https://files.pythonhosted.org/packages/fb/7f/b48d5577338d4b25bbdbae35c75dbfd0493cb8886dc586fbfb2e90862239/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7f78cb080c86fbf765920e5f1ef35af3f24ec4314d6675d0a21eaf41f6f2679c", size = 1621751, upload-time = "2026-03-28T17:15:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/bc/89/4eecad8c1858e6d0893c05929e22343e0ebe3aec29a8a399c65c3cc38311/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:67a3ec705534a614b68bbf1c70efa777a21c3da3895d1c44510a41f5a7ae0453", size = 1826073, upload-time = "2026-03-28T17:15:58.489Z" },
{ url = "https://files.pythonhosted.org/packages/f5/5c/9dc8293ed31b46c39c9c513ac7ca152b3c3d38e0ea111a530ad12001b827/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6630ec917e85c5356b2295744c8a97d40f007f96a1c76bf1928dc2e27465393", size = 1760083, upload-time = "2026-03-28T17:16:00.677Z" },
{ url = "https://files.pythonhosted.org/packages/1e/19/8bbf6a4994205d96831f97b7d21a0feed120136e6267b5b22d229c6dc4dc/aiohttp-3.13.4-cp311-cp311-win32.whl", hash = "sha256:54049021bc626f53a5394c29e8c444f726ee5a14b6e89e0ad118315b1f90f5e3", size = 439690, upload-time = "2026-03-28T17:16:02.902Z" },
{ url = "https://files.pythonhosted.org/packages/0c/f5/ac409ecd1007528d15c3e8c3a57d34f334c70d76cfb7128a28cffdebd4c1/aiohttp-3.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:c033f2bc964156030772d31cbf7e5defea181238ce1f87b9455b786de7d30145", size = 463824, upload-time = "2026-03-28T17:16:05.058Z" },
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" },
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" },
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" },
@@ -1586,7 +1569,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.19.1" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.57.1" },
{ name = "pyrefly", specifier = ">=0.59.1" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@@ -4839,18 +4822,19 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.57.1"
version = "0.59.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
]
[[package]]

View File

@@ -5,7 +5,6 @@
"prepare": "vp config"
},
"devDependencies": {
"taze": "catalog:",
"vite-plus": "catalog:"
},
"engines": {

View File

@@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.33317 3.33333H7.33317V12.6667H5.33317V14H10.6665V12.6667H8.6665V3.33333H10.6665V2H5.33317V3.33333ZM1.33317 4.66667C0.964984 4.66667 0.666504 4.96515 0.666504 5.33333V10.6667C0.666504 11.0349 0.964984 11.3333 1.33317 11.3333H5.33317V10H1.99984V6H5.33317V4.66667H1.33317ZM10.6665 6H13.9998V10H10.6665V11.3333H14.6665C15.0347 11.3333 15.3332 11.0349 15.3332 10.6667V5.33333C15.3332 4.96515 15.0347 4.66667 14.6665 4.66667H10.6665V6Z" fill="#354052"/>
</svg>

After

Width:  |  Height:  |  Size: 563 B

3079
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,9 @@
catalogMode: prefer
trustPolicy: no-downgrade
minimumReleaseAge: 2880
trustPolicyExclude:
- chokidar@4.0.3
- reselect@5.1.1
- semver@6.3.1
blockExoticSubdeps: true
strictDepBuilds: true
allowBuilds:
@@ -23,7 +27,7 @@ overrides:
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
assert: npm:@nolyfill/assert@^1.0.26
brace-expansion@<2.0.2: 2.0.2
brace-expansion@>=2.0.0 <2.0.3: 2.0.3
canvas: ^3.2.2
devalue@<5.3.2: 5.3.2
dompurify@>=3.1.3 <=3.3.1: 3.3.2
@@ -37,6 +41,8 @@ overrides:
is-generator-function: npm:@nolyfill/is-generator-function@^1.0.44
is-typed-array: npm:@nolyfill/is-typed-array@^1.0.44
isarray: npm:@nolyfill/isarray@^1.0.44
lodash@>=4.0.0 <= 4.17.23: 4.18.0
lodash-es@>=4.0.0 <= 4.17.23: 4.18.0
object.assign: npm:@nolyfill/object.assign@^1.0.44
object.entries: npm:@nolyfill/object.entries@^1.0.44
object.fromentries: npm:@nolyfill/object.fromentries@^1.0.44
@@ -64,15 +70,15 @@ overrides:
tar@<=7.5.10: 7.5.11
typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44
undici@>=7.0.0 <7.24.0: 7.24.0
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
yaml@>=2.0.0 <2.8.3: 2.8.3
yauzl@<3.2.1: 3.2.1
catalog:
"@amplitude/analytics-browser": 2.38.0
"@amplitude/plugin-session-replay-browser": 1.27.5
"@antfu/eslint-config": 7.7.3
"@amplitude/analytics-browser": 2.38.1
"@amplitude/plugin-session-replay-browser": 1.27.6
"@antfu/eslint-config": 8.0.0
"@base-ui/react": 1.3.0
"@chromatic-com/storybook": 5.1.1
"@cucumber/cucumber": 12.7.0
@@ -84,7 +90,7 @@ catalog:
"@formatjs/intl-localematcher": 0.8.2
"@headlessui/react": 2.2.9
"@heroicons/react": 2.2.0
"@hono/node-server": 1.19.11
"@hono/node-server": 1.19.12
"@iconify-json/heroicons": 1.2.3
"@iconify-json/ri": 1.2.10
"@lexical/code": 0.42.0
@@ -98,34 +104,35 @@ catalog:
"@mdx-js/react": 3.1.1
"@mdx-js/rollup": 3.1.1
"@monaco-editor/react": 4.7.0
"@next/eslint-plugin-next": 16.2.1
"@next/mdx": 16.2.1
"@next/eslint-plugin-next": 16.2.2
"@next/mdx": 16.2.2
"@orpc/client": 1.13.13
"@orpc/contract": 1.13.13
"@orpc/openapi-client": 1.13.13
"@orpc/tanstack-query": 1.13.13
"@playwright/test": 1.58.2
"@playwright/test": 1.59.1
"@remixicon/react": 4.9.0
"@rgrove/parse-xml": 4.2.0
"@sentry/react": 10.46.0
"@storybook/addon-docs": 10.3.3
"@storybook/addon-links": 10.3.3
"@storybook/addon-onboarding": 10.3.3
"@storybook/addon-themes": 10.3.3
"@storybook/nextjs-vite": 10.3.3
"@storybook/react": 10.3.3
"@sentry/react": 10.47.0
"@storybook/addon-docs": 10.3.4
"@storybook/addon-links": 10.3.4
"@storybook/addon-onboarding": 10.3.4
"@storybook/addon-themes": 10.3.4
"@storybook/nextjs-vite": 10.3.4
"@storybook/react": 10.3.4
"@streamdown/math": 1.0.2
"@svgdotjs/svg.js": 3.2.5
"@t3-oss/env-nextjs": 0.13.11
"@tailwindcss/postcss": 4.2.2
"@tailwindcss/typography": 0.5.19
"@tailwindcss/vite": 4.2.2
"@tanstack/eslint-plugin-query": 5.95.2
"@tanstack/react-devtools": 0.10.0
"@tanstack/react-form": 1.28.5
"@tanstack/react-form-devtools": 0.2.19
"@tanstack/react-query": 5.95.2
"@tanstack/react-query-devtools": 5.95.2
"@tanstack/eslint-plugin-query": 5.96.1
"@tanstack/react-devtools": 0.10.1
"@tanstack/react-form": 1.28.6
"@tanstack/react-form-devtools": 0.2.20
"@tanstack/react-query": 5.96.1
"@tanstack/react-query-devtools": 5.96.1
"@tanstack/react-virtual": 3.13.23
"@testing-library/dom": 10.4.1
"@testing-library/jest-dom": 6.9.1
"@testing-library/react": 16.3.2
@@ -141,15 +148,13 @@ catalog:
"@types/qs": 6.15.0
"@types/react": 19.2.14
"@types/react-dom": 19.2.3
"@types/react-syntax-highlighter": 15.5.13
"@types/react-window": 1.8.8
"@types/sortablejs": 1.15.9
"@typescript-eslint/eslint-plugin": 8.57.2
"@typescript-eslint/parser": 8.57.2
"@typescript/native-preview": 7.0.0-dev.20260329.1
"@typescript-eslint/eslint-plugin": 8.58.0
"@typescript-eslint/parser": 8.58.0
"@typescript/native-preview": 7.0.0-dev.20260401.1
"@vitejs/plugin-react": 6.0.1
"@vitejs/plugin-rsc": 0.5.21
"@vitest/coverage-v8": 4.1.1
"@vitest/coverage-v8": 4.1.2
abcjs: 6.6.2
agentation: 3.0.2
ahooks: 3.9.7
@@ -157,7 +162,7 @@ catalog:
class-variance-authority: 0.7.1
clsx: 2.1.1
cmdk: 1.1.1
code-inspector-plugin: 1.4.5
code-inspector-plugin: 1.5.1
copy-to-clipboard: 3.3.3
cron-parser: 5.5.0
dayjs: 1.11.20
@@ -174,19 +179,19 @@ catalog:
eslint-markdown: 0.6.0
eslint-plugin-better-tailwindcss: 4.3.2
eslint-plugin-hyoban: 0.14.1
eslint-plugin-markdown-preferences: 0.40.3
eslint-plugin-markdown-preferences: 0.41.0
eslint-plugin-no-barrel-files: 1.2.2
eslint-plugin-react-hooks: 7.0.1
eslint-plugin-react-refresh: 0.5.2
eslint-plugin-sonarjs: 4.0.2
eslint-plugin-storybook: 10.3.3
eslint-plugin-storybook: 10.3.4
fast-deep-equal: 3.1.3
foxact: 0.3.0
happy-dom: 20.8.9
hono: 4.12.9
hast-util-to-jsx-runtime: 2.3.6
hono: 4.12.10
html-entities: 2.6.0
html-to-image: 1.11.13
i18next: 25.10.10
i18next: 26.0.3
i18next-resources-to-backend: 1.2.1
iconify-import-svg: 0.1.2
immer: 11.1.4
@@ -196,15 +201,15 @@ catalog:
js-yaml: 4.1.1
jsonschema: 1.5.0
katex: 0.16.44
knip: 6.1.0
knip: 6.2.0
ky: 1.14.3
lamejs: 1.2.1
lexical: 0.42.0
mermaid: 11.13.0
mermaid: 11.14.0
mime: 4.1.0
mitt: 3.0.1
negotiator: 1.0.0
next: 16.2.1
next: 16.2.2
next-themes: 0.4.6
nuqs: 2.8.9
pinyin-pro: 3.28.0
@@ -217,42 +222,39 @@ catalog:
react-dom: 19.2.4
react-easy-crop: 5.5.7
react-hotkeys-hook: 5.2.4
react-i18next: 16.6.6
react-i18next: 17.0.2
react-multi-email: 1.0.25
react-papaparse: 4.4.0
react-pdf-highlighter: 8.0.0-rc.0
react-server-dom-webpack: 19.2.4
react-sortablejs: 6.1.4
react-syntax-highlighter: 15.6.6
react-textarea-autosize: 8.5.9
react-window: 1.8.11
reactflow: 11.11.4
remark-breaks: 4.0.0
remark-directive: 4.0.0
sass: 1.98.0
scheduler: 0.27.0
sharp: 0.34.5
shiki: 4.0.2
sortablejs: 1.15.7
std-semver: 1.0.8
storybook: 10.3.3
storybook: 10.3.4
streamdown: 2.5.0
string-ts: 2.3.1
tailwind-merge: 3.5.0
tailwindcss: 4.2.2
taze: 19.10.0
tldts: 7.0.27
tsup: ^8.5.1
tsdown: 0.21.7
tsx: 4.21.0
typescript: 5.9.3
typescript: 6.0.2
uglify-js: 3.19.3
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 13.0.0
vinext: 0.0.38
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
vinext: 0.0.39
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.14
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
vite-plus: 0.1.15
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
vitest-canvas-mock: 1.1.4
zod: 4.3.6
zundo: 2.3.0

View File

@@ -45,12 +45,12 @@
"homepage": "https://dify.ai",
"license": "MIT",
"scripts": {
"build": "tsup",
"build": "vp pack",
"lint": "eslint",
"lint:fix": "eslint --fix",
"type-check": "tsc -p tsconfig.json --noEmit",
"test": "vitest run",
"test:coverage": "vitest run --coverage",
"test": "vp test",
"test:coverage": "vp test --coverage",
"publish:check": "./scripts/publish.sh --dry-run",
"publish:npm": "./scripts/publish.sh"
},
@@ -61,8 +61,8 @@
"@typescript-eslint/parser": "catalog:",
"@vitest/coverage-v8": "catalog:",
"eslint": "catalog:",
"tsup": "catalog:",
"typescript": "catalog:",
"vite-plus": "catalog:",
"vitest": "catalog:"
}
}

View File

@@ -11,7 +11,8 @@
"strict": true,
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"skipLibCheck": true
"skipLibCheck": true,
"types": ["node"]
},
"include": ["src/**/*.ts", "tests/**/*.ts"]
}

View File

@@ -1,12 +0,0 @@
import { defineConfig } from "tsup";
export default defineConfig({
entry: ["src/index.ts"],
format: ["esm"],
dts: true,
clean: true,
sourcemap: true,
splitting: false,
treeshake: true,
outDir: "dist",
});

View File

@@ -1,6 +1,17 @@
import { defineConfig } from "vitest/config";
import { defineConfig } from "vite-plus";
export default defineConfig({
pack: {
entry: ["src/index.ts"],
format: ["esm"],
dts: true,
clean: true,
sourcemap: true,
// splitting: false,
treeshake: true,
outDir: "dist",
target: false,
},
test: {
environment: "node",
include: ["**/*.test.ts"],

View File

@@ -1,15 +0,0 @@
import { defineConfig } from 'taze'
export default defineConfig({
exclude: [
// We are going to replace these
'react-syntax-highlighter',
'react-window',
'@types/react-window',
// We can not upgrade these yet
'typescript',
],
maturityPeriod: 2,
})

View File

@@ -0,0 +1,36 @@
import { vi } from 'vitest'
const mockVirtualizer = ({
count,
estimateSize,
}: {
count: number
estimateSize?: (index: number) => number
}) => {
const getSize = (index: number) => estimateSize?.(index) ?? 0
return {
getTotalSize: () => Array.from({ length: count }).reduce<number>((total, _, index) => total + getSize(index), 0),
getVirtualItems: () => {
let start = 0
return Array.from({ length: count }).map((_, index) => {
const size = getSize(index)
const virtualItem = {
end: start + size,
index,
key: index,
size,
start,
}
start += size
return virtualItem
})
},
measureElement: vi.fn(),
scrollToIndex: vi.fn(),
}
}
export { mockVirtualizer as useVirtualizer }

View File

@@ -0,0 +1,11 @@
import Evaluation from '@/app/components/evaluation'
const Page = async (props: {
params: Promise<{ appId: string }>
}) => {
const { appId } = await props.params
return <Evaluation resourceType="workflow" resourceId={appId} />
}
export default Page

View File

@@ -7,6 +7,8 @@ import {
RiDashboard2Line,
RiFileList3Fill,
RiFileList3Line,
RiFlaskFill,
RiFlaskLine,
RiTerminalBoxFill,
RiTerminalBoxLine,
RiTerminalWindowFill,
@@ -35,7 +37,7 @@ const TagManagementModal = dynamic(() => import('@/app/components/base/tag-manag
ssr: false,
})
export type IAppDetailLayoutProps = {
type IAppDetailLayoutProps = {
children: React.ReactNode
appId: string
}
@@ -67,40 +69,47 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
}>>([])
const getNavigationConfig = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: AppModeEnum) => {
const navConfig = [
...(isCurrentWorkspaceEditor
? [{
name: t('appMenus.promptEng', { ns: 'common' }),
href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`,
icon: RiTerminalWindowLine,
selectedIcon: RiTerminalWindowFill,
}]
: []
),
{
name: t('appMenus.apiAccess', { ns: 'common' }),
href: `/app/${appId}/develop`,
icon: RiTerminalBoxLine,
selectedIcon: RiTerminalBoxFill,
},
...(isCurrentWorkspaceEditor
? [{
name: mode !== AppModeEnum.WORKFLOW
? t('appMenus.logAndAnn', { ns: 'common' })
: t('appMenus.logs', { ns: 'common' }),
href: `/app/${appId}/logs`,
icon: RiFileList3Line,
selectedIcon: RiFileList3Fill,
}]
: []
),
{
name: t('appMenus.overview', { ns: 'common' }),
href: `/app/${appId}/overview`,
icon: RiDashboard2Line,
selectedIcon: RiDashboard2Fill,
},
]
const navConfig = []
if (isCurrentWorkspaceEditor) {
navConfig.push({
name: t('appMenus.promptEng', { ns: 'common' }),
href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`,
icon: RiTerminalWindowLine,
selectedIcon: RiTerminalWindowFill,
})
navConfig.push({
name: t('appMenus.evaluation', { ns: 'common' }),
href: `/app/${appId}/evaluation`,
icon: RiFlaskLine,
selectedIcon: RiFlaskFill,
})
}
navConfig.push({
name: t('appMenus.apiAccess', { ns: 'common' }),
href: `/app/${appId}/develop`,
icon: RiTerminalBoxLine,
selectedIcon: RiTerminalBoxFill,
})
if (isCurrentWorkspaceEditor) {
navConfig.push({
name: mode !== AppModeEnum.WORKFLOW
? t('appMenus.logAndAnn', { ns: 'common' })
: t('appMenus.logs', { ns: 'common' }),
href: `/app/${appId}/logs`,
icon: RiFileList3Line,
selectedIcon: RiFileList3Fill,
})
}
navConfig.push({
name: t('appMenus.overview', { ns: 'common' }),
href: `/app/${appId}/overview`,
icon: RiDashboard2Line,
selectedIcon: RiDashboard2Fill,
})
return navConfig
}, [t])

View File

@@ -25,7 +25,7 @@ import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum } from '@/types/app'
import { asyncRunSafe } from '@/utils'
export type ICardViewProps = {
type ICardViewProps = {
appId: string
isInPanel?: boolean
className?: string

View File

@@ -27,7 +27,7 @@ const TIME_PERIOD_MAPPING: { value: number, name: TimePeriodName }[] = [
const queryDateFormat = 'YYYY-MM-DD HH:mm'
export type IChartViewProps = {
type IChartViewProps = {
appId: string
headerRight: React.ReactNode
}

View File

@@ -1,5 +1,3 @@
@reference "../../../../styles/globals.css";
.app {
flex-grow: 1;
height: 0;

View File

@@ -1,11 +0,0 @@
@reference "../../../../../styles/globals.css";
.logTable td {
padding: 7px 8px;
box-sizing: border-box;
max-width: 200px;
}
.pagination li {
list-style: none;
}

View File

@@ -0,0 +1,11 @@
import Evaluation from '@/app/components/evaluation'
const Page = async (props: {
params: Promise<{ datasetId: string }>
}) => {
const { datasetId } = await props.params
return <Evaluation resourceType="pipeline" resourceId={datasetId} />
}
export default Page

View File

@@ -6,6 +6,8 @@ import {
RiEqualizer2Line,
RiFileTextFill,
RiFileTextLine,
RiFlaskFill,
RiFlaskLine,
RiFocus2Fill,
RiFocus2Line,
} from '@remixicon/react'
@@ -26,7 +28,7 @@ import { usePathname } from '@/next/navigation'
import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset'
import { cn } from '@/utils/classnames'
export type IAppDetailLayoutProps = {
type IAppDetailLayoutProps = {
children: React.ReactNode
datasetId: string
}
@@ -86,20 +88,30 @@ const DatasetDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
]
if (datasetRes?.provider !== 'external') {
baseNavigation.unshift({
name: t('datasetMenus.pipeline', { ns: 'common' }),
href: `/datasets/${datasetId}/pipeline`,
icon: PipelineLine as RemixiconComponentType,
selectedIcon: PipelineFill as RemixiconComponentType,
disabled: false,
})
baseNavigation.unshift({
name: t('datasetMenus.documents', { ns: 'common' }),
href: `/datasets/${datasetId}/documents`,
icon: RiFileTextLine,
selectedIcon: RiFileTextFill,
disabled: isButtonDisabledWithPipeline,
})
return [
{
name: t('datasetMenus.documents', { ns: 'common' }),
href: `/datasets/${datasetId}/documents`,
icon: RiFileTextLine,
selectedIcon: RiFileTextFill,
disabled: isButtonDisabledWithPipeline,
},
{
name: t('datasetMenus.pipeline', { ns: 'common' }),
href: `/datasets/${datasetId}/pipeline`,
icon: PipelineLine as RemixiconComponentType,
selectedIcon: PipelineFill as RemixiconComponentType,
disabled: false,
},
{
name: t('datasetMenus.evaluation', { ns: 'common' }),
href: `/datasets/${datasetId}/evaluation`,
icon: RiFlaskLine,
selectedIcon: RiFlaskFill,
disabled: false,
},
...baseNavigation,
]
}
return baseNavigation

View File

@@ -6,7 +6,7 @@ import Loading from '@/app/components/base/loading'
import { useAppContext } from '@/context/app-context'
import { usePathname, useRouter } from '@/next/navigation'
const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const
const datasetOperatorRedirectRoutes = ['/apps', '/app', '/snippets', '/explore', '/tools'] as const
const isPathUnderRoute = (pathname: string, route: string) => pathname === route || pathname.startsWith(`${route}/`)

View File

@@ -0,0 +1,11 @@
import SnippetEvaluationPage from '@/app/components/snippets/snippet-evaluation-page'
const Page = async (props: {
params: Promise<{ snippetId: string }>
}) => {
const { snippetId } = await props.params
return <SnippetEvaluationPage snippetId={snippetId} />
}
export default Page

View File

@@ -0,0 +1,11 @@
import SnippetPage from '@/app/components/snippets'
const Page = async (props: {
params: Promise<{ snippetId: string }>
}) => {
const { snippetId } = await props.params
return <SnippetPage snippetId={snippetId} />
}
export default Page

View File

@@ -0,0 +1,21 @@
import Page from './page'
const mockRedirect = vi.fn()
vi.mock('next/navigation', () => ({
redirect: (path: string) => mockRedirect(path),
}))
describe('snippet detail redirect page', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should redirect legacy snippet detail routes to orchestrate', async () => {
await Page({
params: Promise.resolve({ snippetId: 'snippet-1' }),
})
expect(mockRedirect).toHaveBeenCalledWith('/snippets/snippet-1/orchestrate')
})
})

View File

@@ -0,0 +1,11 @@
import { redirect } from 'next/navigation'
const Page = async (props: {
params: Promise<{ snippetId: string }>
}) => {
const { snippetId } = await props.params
redirect(`/snippets/${snippetId}/orchestrate`)
}
export default Page

View File

@@ -0,0 +1,7 @@
import Apps from '@/app/components/apps'
const SnippetsPage = () => {
return <Apps pageType="snippets" />
}
export default SnippetsPage

View File

@@ -13,10 +13,6 @@ import { useProviderContext } from '@/context/provider-context'
import { useRouter } from '@/next/navigation'
import { useLogout, useUserProfile } from '@/service/use-common'
export type IAppSelector = {
isMobile: boolean
}
export default function AppSelector() {
const router = useRouter()
const { t } = useTranslation()

View File

@@ -165,6 +165,21 @@ describe('AppDetailNav', () => {
)
expect(screen.queryByTestId('extra-info')).not.toBeInTheDocument()
})
it('should render custom header and navigation when provided', () => {
render(
<AppDetailNav
navigation={navigation}
renderHeader={mode => <div data-testid="custom-header" data-mode={mode} />}
renderNavigation={mode => <div data-testid="custom-navigation" data-mode={mode} />}
/>,
)
expect(screen.getByTestId('custom-header')).toHaveAttribute('data-mode', 'expand')
expect(screen.getByTestId('custom-navigation')).toHaveAttribute('data-mode', 'expand')
expect(screen.queryByTestId('app-info')).not.toBeInTheDocument()
expect(screen.queryByTestId('nav-link-Overview')).not.toBeInTheDocument()
})
})
describe('Workflow canvas mode', () => {

View File

@@ -2,7 +2,7 @@ import type { App, AppSSO } from '@/types/app'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { AppModeEnum } from '@/types/app'
import { AppModeEnum, AppTypeEnum } from '@/types/app'
import AppInfoDetailPanel from '../app-info-detail-panel'
vi.mock('../../../base/app-icon', () => ({
@@ -135,6 +135,17 @@ describe('AppInfoDetailPanel', () => {
expect(cardView).toHaveAttribute('data-app-id', 'app-1')
})
it('should not render CardView when app type is evaluation', () => {
render(
<AppInfoDetailPanel
{...defaultProps}
appDetail={createAppDetail({ type: AppTypeEnum.EVALUATION })}
/>,
)
expect(screen.queryByTestId('card-view')).not.toBeInTheDocument()
})
it('should render app icon with large size', () => {
render(<AppInfoDetailPanel {...defaultProps} />)
const icon = screen.getByTestId('app-icon')

View File

@@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next'
import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view'
import Button from '@/app/components/base/button'
import ContentDialog from '@/app/components/base/content-dialog'
import { AppModeEnum } from '@/types/app'
import { AppModeEnum, AppTypeEnum } from '@/types/app'
import AppIcon from '../../base/app-icon'
import { getAppModeLabel } from './app-mode-labels'
import AppOperations from './app-operations'
@@ -97,7 +97,7 @@ const AppInfoDetailPanel = ({
<ContentDialog
show={show}
onClose={onClose}
className="absolute bottom-2 left-2 top-2 flex w-[420px] flex-col rounded-2xl p-0!"
className="absolute top-2 bottom-2 left-2 flex w-[420px] flex-col rounded-2xl p-0!"
>
<div className="flex shrink-0 flex-col items-start justify-center gap-3 self-stretch p-4">
<div className="flex items-center gap-3 self-stretch">
@@ -109,14 +109,14 @@ const AppInfoDetailPanel = ({
imageUrl={appDetail.icon_url}
/>
<div className="flex flex-1 flex-col items-start justify-center overflow-hidden">
<div className="w-full truncate text-text-secondary system-md-semibold">{appDetail.name}</div>
<div className="text-text-tertiary system-2xs-medium-uppercase">
<div className="w-full truncate system-md-semibold text-text-secondary">{appDetail.name}</div>
<div className="system-2xs-medium-uppercase text-text-tertiary">
{getAppModeLabel(appDetail.mode, t)}
</div>
</div>
</div>
{appDetail.description && (
<div className="overflow-wrap-anywhere max-h-[105px] w-full max-w-full overflow-y-auto whitespace-normal wrap-break-word text-text-tertiary system-xs-regular">
<div className="overflow-wrap-anywhere max-h-[105px] w-full max-w-full overflow-y-auto system-xs-regular wrap-break-word whitespace-normal text-text-tertiary">
{appDetail.description}
</div>
)}
@@ -126,11 +126,13 @@ const AppInfoDetailPanel = ({
secondaryOperations={secondaryOperations}
/>
</div>
<CardView
appId={appDetail.id}
isInPanel={true}
className="flex flex-1 flex-col gap-2 overflow-auto px-2 py-1"
/>
{appDetail.type !== AppTypeEnum.EVALUATION && (
<CardView
appId={appDetail.id}
isInPanel={true}
className="flex flex-1 flex-col gap-2 overflow-auto px-2 py-1"
/>
)}
{switchOperation && (
<div className="flex min-h-fit shrink-0 flex-col items-start justify-center gap-3 self-stretch pb-2">
<Button
@@ -140,7 +142,7 @@ const AppInfoDetailPanel = ({
onClick={switchOperation.onClick}
>
{switchOperation.icon}
<span className="text-text-tertiary system-sm-medium">{switchOperation.title}</span>
<span className="system-sm-medium text-text-tertiary">{switchOperation.title}</span>
</Button>
</div>
)}

View File

@@ -5,7 +5,7 @@ import AppInfoModals from './app-info-modals'
import AppInfoTrigger from './app-info-trigger'
import { useAppInfoActions } from './use-app-info-actions'
export type IAppInfoProps = {
type IAppInfoProps = {
expand: boolean
onlyShowDetail?: boolean
openState?: boolean

View File

@@ -7,7 +7,7 @@ import {
import Tooltip from '@/app/components/base/tooltip'
import AppIcon from '../base/app-icon'
export type IAppBasicProps = {
type IAppBasicProps = {
iconType?: 'app' | 'api' | 'dataset' | 'webapp' | 'notion'
icon?: string
icon_background?: string | null

View File

@@ -17,7 +17,7 @@ import DatasetSidebarDropdown from './dataset-sidebar-dropdown'
import NavLink from './nav-link'
import ToggleButton from './toggle-button'
export type IAppDetailNavProps = {
type IAppDetailNavProps = {
iconType?: 'app' | 'dataset'
navigation: Array<{
name: string
@@ -27,12 +27,16 @@ export type IAppDetailNavProps = {
disabled?: boolean
}>
extraInfo?: (modeState: string) => React.ReactNode
renderHeader?: (modeState: string) => React.ReactNode
renderNavigation?: (modeState: string) => React.ReactNode
}
const AppDetailNav = ({
navigation,
extraInfo,
iconType = 'app',
renderHeader,
renderNavigation,
}: IAppDetailNavProps) => {
const { appSidebarExpand, setAppSidebarExpand } = useAppStore(useShallow(state => ({
appSidebarExpand: state.appSidebarExpand,
@@ -104,10 +108,11 @@ const AppDetailNav = ({
expand ? 'p-2' : 'p-1',
)}
>
{iconType === 'app' && (
{renderHeader?.(appSidebarExpand)}
{!renderHeader && iconType === 'app' && (
<AppInfo expand={expand} />
)}
{iconType !== 'app' && (
{!renderHeader && iconType !== 'app' && (
<DatasetInfo expand={expand} />
)}
</div>
@@ -136,7 +141,8 @@ const AppDetailNav = ({
expand ? 'px-3 py-2' : 'p-3',
)}
>
{navigation.map((item, index) => {
{renderNavigation?.(appSidebarExpand)}
{!renderNavigation && navigation.map((item, index) => {
return (
<NavLink
key={index}

View File

@@ -262,4 +262,20 @@ describe('NavLink Animation and Layout Issues', () => {
expect(iconWrapper).toHaveClass('-ml-1')
})
})
describe('Button Mode', () => {
it('should render as an interactive button when href is omitted', () => {
const onClick = vi.fn()
render(<NavLink {...mockProps} href={undefined} active={true} onClick={onClick} />)
const buttonElement = screen.getByText('Orchestrate').closest('button')
expect(buttonElement).not.toBeNull()
expect(buttonElement).toHaveClass('bg-components-menu-item-bg-active')
expect(buttonElement).toHaveClass('text-text-accent-light-mode-only')
buttonElement?.click()
expect(onClick).toHaveBeenCalledTimes(1)
})
})
})

View File

@@ -14,13 +14,15 @@ export type NavIcon = React.ComponentType<
export type NavLinkProps = {
name: string
href: string
href?: string
iconMap: {
selected: NavIcon
normal: NavIcon
}
mode?: string
disabled?: boolean
active?: boolean
onClick?: () => void
}
const NavLink = ({
@@ -29,6 +31,8 @@ const NavLink = ({
iconMap,
mode = 'expand',
disabled = false,
active,
onClick,
}: NavLinkProps) => {
const segment = useSelectedLayoutSegment()
const formattedSegment = (() => {
@@ -39,8 +43,11 @@ const NavLink = ({
return res
})()
const isActive = href.toLowerCase().split('/')?.pop() === formattedSegment
const isActive = active ?? (href ? href.toLowerCase().split('/')?.pop() === formattedSegment : false)
const NavIcon = isActive ? iconMap.selected : iconMap.normal
const linkClassName = cn(isActive
? 'border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only system-sm-semibold'
: 'text-components-menu-item-text system-sm-medium hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover', 'flex h-8 items-center rounded-lg pl-3 pr-1')
const renderIcon = () => (
<div className={cn(mode !== 'expand' && '-ml-1')}>
@@ -70,13 +77,32 @@ const NavLink = ({
)
}
if (!href) {
return (
<button
key={name}
type="button"
className={linkClassName}
title={mode === 'collapse' ? name : ''}
onClick={onClick}
>
{renderIcon()}
<span
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out', mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0')}
>
{name}
</span>
</button>
)
}
return (
<Link
key={name}
href={href}
className={cn(isActive
? 'border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only system-sm-semibold'
: 'text-components-menu-item-text system-sm-medium hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover', 'flex h-8 items-center rounded-lg pl-3 pr-1')}
className={linkClassName}
title={mode === 'collapse' ? name : ''}
>
{renderIcon()}

View File

@@ -0,0 +1,285 @@
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
import type { CreateSnippetDialogPayload } from '@/app/components/workflow/create-snippet-dialog'
import type { SnippetDetail } from '@/models/snippet'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import SnippetInfoDropdown from '../dropdown'
const mockReplace = vi.fn()
const mockDownloadBlob = vi.fn()
const mockToastSuccess = vi.fn()
const mockToastError = vi.fn()
const mockUpdateMutate = vi.fn()
const mockExportMutateAsync = vi.fn()
const mockDeleteMutate = vi.fn()
let mockDropdownOpen = false
let mockDropdownOnOpenChange: ((open: boolean) => void) | undefined
vi.mock('@/next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
}))
vi.mock('@/utils/download', () => ({
downloadBlob: (args: { data: Blob, fileName: string }) => mockDownloadBlob(args),
}))
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
success: (...args: unknown[]) => mockToastSuccess(...args),
error: (...args: unknown[]) => mockToastError(...args),
},
}))
vi.mock('@/app/components/base/ui/dropdown-menu', () => ({
DropdownMenu: ({
open,
onOpenChange,
children,
}: {
open?: boolean
onOpenChange?: (open: boolean) => void
children: React.ReactNode
}) => {
mockDropdownOpen = !!open
mockDropdownOnOpenChange = onOpenChange
return <div>{children}</div>
},
DropdownMenuTrigger: ({
children,
className,
}: {
children: React.ReactNode
className?: string
}) => (
<button
type="button"
className={className}
onClick={() => mockDropdownOnOpenChange?.(!mockDropdownOpen)}
>
{children}
</button>
),
DropdownMenuContent: ({ children }: { children: React.ReactNode }) => (
mockDropdownOpen ? <div>{children}</div> : null
),
DropdownMenuItem: ({
children,
onClick,
}: {
children: React.ReactNode
onClick?: () => void
}) => (
<button type="button" onClick={onClick}>
{children}
</button>
),
DropdownMenuSeparator: () => <hr />,
}))
vi.mock('@/service/use-snippets', () => ({
useUpdateSnippetMutation: () => ({
mutate: mockUpdateMutate,
isPending: false,
}),
useExportSnippetMutation: () => ({
mutateAsync: mockExportMutateAsync,
isPending: false,
}),
useDeleteSnippetMutation: () => ({
mutate: mockDeleteMutate,
isPending: false,
}),
}))
type MockCreateSnippetDialogProps = {
isOpen: boolean
title?: string
confirmText?: string
initialValue?: {
name?: string
description?: string
icon?: AppIconSelection
}
onClose: () => void
onConfirm: (payload: CreateSnippetDialogPayload) => void
}
vi.mock('@/app/components/workflow/create-snippet-dialog', () => ({
default: ({
isOpen,
title,
confirmText,
initialValue,
onClose,
onConfirm,
}: MockCreateSnippetDialogProps) => {
if (!isOpen)
return null
return (
<div data-testid="create-snippet-dialog">
<div>{title}</div>
<div>{confirmText}</div>
<div>{initialValue?.name}</div>
<div>{initialValue?.description}</div>
<button
type="button"
onClick={() => onConfirm({
name: 'Updated snippet',
description: 'Updated description',
icon: {
type: 'emoji',
icon: '✨',
background: '#FFFFFF',
},
graph: {
nodes: [],
edges: [],
viewport: { x: 0, y: 0, zoom: 1 },
},
})}
>
submit-edit
</button>
<button type="button" onClick={onClose}>close-edit</button>
</div>
)
},
}))
const mockSnippet: SnippetDetail = {
id: 'snippet-1',
name: 'Social Media Repurposer',
description: 'Turn one blog post into multiple social media variations.',
author: 'Dify',
updatedAt: '2026-03-25 10:00',
usage: '12',
icon: '🤖',
iconBackground: '#F0FDF9',
status: undefined,
}
describe('SnippetInfoDropdown', () => {
beforeEach(() => {
vi.clearAllMocks()
mockDropdownOpen = false
mockDropdownOnOpenChange = undefined
})
// Rendering coverage for the menu trigger itself.
describe('Rendering', () => {
it('should render the dropdown trigger button', () => {
render(<SnippetInfoDropdown snippet={mockSnippet} />)
expect(screen.getByRole('button')).toBeInTheDocument()
})
})
// Edit flow should seed the dialog with current snippet info and submit updates.
describe('Edit Snippet', () => {
it('should open the edit dialog and submit snippet updates', async () => {
const user = userEvent.setup()
mockUpdateMutate.mockImplementation((_variables: unknown, options?: { onSuccess?: () => void }) => {
options?.onSuccess?.()
})
render(<SnippetInfoDropdown snippet={mockSnippet} />)
await user.click(screen.getByRole('button'))
await user.click(screen.getByText('snippet.menu.editInfo'))
expect(screen.getByTestId('create-snippet-dialog')).toBeInTheDocument()
expect(screen.getByText('snippet.editDialogTitle')).toBeInTheDocument()
expect(screen.getByText('common.operation.save')).toBeInTheDocument()
expect(screen.getByText(mockSnippet.name)).toBeInTheDocument()
expect(screen.getByText(mockSnippet.description)).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'submit-edit' }))
expect(mockUpdateMutate).toHaveBeenCalledWith({
params: { snippetId: mockSnippet.id },
body: {
name: 'Updated snippet',
description: 'Updated description',
icon_info: {
icon: '✨',
icon_type: 'emoji',
icon_background: '#FFFFFF',
icon_url: undefined,
},
},
}, expect.objectContaining({
onSuccess: expect.any(Function),
onError: expect.any(Function),
}))
expect(mockToastSuccess).toHaveBeenCalledWith('snippet.editDone')
})
})
// Export should call the export hook and download the returned YAML blob.
describe('Export Snippet', () => {
it('should export and download the snippet yaml', async () => {
const user = userEvent.setup()
mockExportMutateAsync.mockResolvedValue('yaml: content')
render(<SnippetInfoDropdown snippet={mockSnippet} />)
await user.click(screen.getByRole('button'))
await user.click(screen.getByText('snippet.menu.exportSnippet'))
await waitFor(() => {
expect(mockExportMutateAsync).toHaveBeenCalledWith({ snippetId: mockSnippet.id })
})
expect(mockDownloadBlob).toHaveBeenCalledWith({
data: expect.any(Blob),
fileName: `${mockSnippet.name}.yml`,
})
})
it('should show an error toast when export fails', async () => {
const user = userEvent.setup()
mockExportMutateAsync.mockRejectedValue(new Error('export failed'))
render(<SnippetInfoDropdown snippet={mockSnippet} />)
await user.click(screen.getByRole('button'))
await user.click(screen.getByText('snippet.menu.exportSnippet'))
await waitFor(() => {
expect(mockToastError).toHaveBeenCalledWith('snippet.exportFailed')
})
})
})
// Delete should require confirmation and redirect after a successful mutation.
describe('Delete Snippet', () => {
it('should confirm deletion and redirect to the snippets list', async () => {
const user = userEvent.setup()
mockDeleteMutate.mockImplementation((_variables: unknown, options?: { onSuccess?: () => void }) => {
options?.onSuccess?.()
})
render(<SnippetInfoDropdown snippet={mockSnippet} />)
await user.click(screen.getByRole('button'))
await user.click(screen.getByText('snippet.menu.deleteSnippet'))
expect(screen.getByText('snippet.deleteConfirmTitle')).toBeInTheDocument()
expect(screen.getByText('snippet.deleteConfirmContent')).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'snippet.menu.deleteSnippet' }))
expect(mockDeleteMutate).toHaveBeenCalledWith({
params: { snippetId: mockSnippet.id },
}, expect.objectContaining({
onSuccess: expect.any(Function),
onError: expect.any(Function),
}))
expect(mockToastSuccess).toHaveBeenCalledWith('snippet.deleted')
expect(mockReplace).toHaveBeenCalledWith('/snippets')
})
})
})

View File

@@ -0,0 +1,62 @@
// Unit tests for the SnippetInfo sidebar header: expanded vs collapsed
// rendering and optional-field edge cases.
import type { SnippetDetail } from '@/models/snippet'
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import SnippetInfo from '..'

// The action menu has its own test suite; stub it here so these tests
// exercise the header layout alone.
vi.mock('../dropdown', () => ({
  default: () => <div data-testid="snippet-info-dropdown" />,
}))

// Representative snippet used across all cases.
const mockSnippet: SnippetDetail = {
  id: 'snippet-1',
  name: 'Social Media Repurposer',
  description: 'Turn one blog post into multiple social media variations.',
  author: 'Dify',
  updatedAt: '2026-03-25 10:00',
  usage: '12',
  icon: '🤖',
  iconBackground: '#F0FDF9',
  status: undefined,
}

describe('SnippetInfo', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // Rendering tests for the collapsed and expanded sidebar header states.
  describe('Rendering', () => {
    it('should render the expanded snippet details and dropdown when expand is true', () => {
      render(<SnippetInfo expand={true} snippet={mockSnippet} />)
      expect(screen.getByText(mockSnippet.name)).toBeInTheDocument()
      expect(screen.getByText('snippet.typeLabel')).toBeInTheDocument()
      expect(screen.getByText(mockSnippet.description)).toBeInTheDocument()
      expect(screen.getByTestId('snippet-info-dropdown')).toBeInTheDocument()
    })

    it('should hide the expanded-only content when expand is false', () => {
      render(<SnippetInfo expand={false} snippet={mockSnippet} />)
      expect(screen.queryByText(mockSnippet.name)).not.toBeInTheDocument()
      expect(screen.queryByText('snippet.typeLabel')).not.toBeInTheDocument()
      expect(screen.queryByText(mockSnippet.description)).not.toBeInTheDocument()
      expect(screen.queryByTestId('snippet-info-dropdown')).not.toBeInTheDocument()
    })
  })

  // Edge cases around optional snippet fields should not break the header layout.
  describe('Edge Cases', () => {
    it('should omit the description block when the snippet has no description', () => {
      render(
        <SnippetInfo
          expand={true}
          snippet={{ ...mockSnippet, description: '' }}
        />,
      )
      expect(screen.getByText(mockSnippet.name)).toBeInTheDocument()
      expect(screen.queryByText(mockSnippet.description)).not.toBeInTheDocument()
    })
  })
})

View File

@@ -0,0 +1,197 @@
'use client'
// Dropdown menu attached to the snippet sidebar header. Offers three actions:
// edit the snippet's name/description/icon, export it as a YAML file, and
// delete it (behind a confirmation dialog) before navigating back to the list.
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
import type { SnippetDetail } from '@/models/snippet'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import {
  AlertDialog,
  AlertDialogActions,
  AlertDialogCancelButton,
  AlertDialogConfirmButton,
  AlertDialogContent,
  AlertDialogDescription,
  AlertDialogTitle,
} from '@/app/components/base/ui/alert-dialog'
import {
  DropdownMenu,
  DropdownMenuContent,
  DropdownMenuItem,
  DropdownMenuSeparator,
  DropdownMenuTrigger,
} from '@/app/components/base/ui/dropdown-menu'
import { toast } from '@/app/components/base/ui/toast'
import CreateSnippetDialog from '@/app/components/workflow/create-snippet-dialog'
import { useRouter } from '@/next/navigation'
import { useDeleteSnippetMutation, useExportSnippetMutation, useUpdateSnippetMutation } from '@/service/use-snippets'
import { cn } from '@/utils/classnames'
import { downloadBlob } from '@/utils/download'

type SnippetInfoDropdownProps = {
  snippet: SnippetDetail
}

// Default emoji icon used when the snippet carries no icon of its own.
const FALLBACK_ICON: AppIconSelection = {
  type: 'emoji',
  icon: '🤖',
  background: '#FFEAD5',
}

const SnippetInfoDropdown = ({ snippet }: SnippetInfoDropdownProps) => {
  const { t } = useTranslation('snippet')
  const { replace } = useRouter()
  // Menu open state; the two dialogs are controlled separately so the menu
  // can be closed before a dialog takes over.
  const [open, setOpen] = React.useState(false)
  const [isEditDialogOpen, setIsEditDialogOpen] = React.useState(false)
  const [isDeleteDialogOpen, setIsDeleteDialogOpen] = React.useState(false)
  const updateSnippetMutation = useUpdateSnippetMutation()
  const exportSnippetMutation = useExportSnippetMutation()
  const deleteSnippetMutation = useDeleteSnippetMutation()

  // Seed values for the edit dialog. The stored icon is treated as an emoji
  // here; FALLBACK_ICON fills in when the snippet has none.
  const initialValue = React.useMemo(() => ({
    name: snippet.name,
    description: snippet.description,
    icon: snippet.icon
      ? {
        type: 'emoji' as const,
        icon: snippet.icon,
        background: snippet.iconBackground || FALLBACK_ICON.background,
      }
      : FALLBACK_ICON,
  }), [snippet.description, snippet.icon, snippet.iconBackground, snippet.name])

  // Close the menu first, then open the edit dialog.
  const handleOpenEditDialog = React.useCallback(() => {
    setOpen(false)
    setIsEditDialogOpen(true)
  }, [])

  // Export the snippet definition and trigger a browser download of the
  // returned YAML; failures surface as a toast rather than throwing.
  const handleExportSnippet = React.useCallback(async () => {
    setOpen(false)
    try {
      const data = await exportSnippetMutation.mutateAsync({ snippetId: snippet.id })
      const file = new Blob([data], { type: 'application/yaml' })
      downloadBlob({ data: file, fileName: `${snippet.name}.yml` })
    }
    catch {
      toast.error(t('exportFailed'))
    }
  }, [exportSnippetMutation, snippet.id, snippet.name, t])

  // Persist name/description/icon edits coming back from the dialog.
  // NOTE(review): declared async but contains no await — mutate() is
  // fire-and-forget with callbacks; the async keyword could be dropped.
  const handleEditSnippet = React.useCallback(async ({ name, description, icon }: {
    name: string
    description: string
    icon: AppIconSelection
  }) => {
    updateSnippetMutation.mutate({
      params: { snippetId: snippet.id },
      body: {
        name,
        // Empty description is sent as undefined rather than ''.
        description: description || undefined,
        // Emoji icons store the glyph + background; image icons store the
        // uploaded file id + URL instead.
        icon_info: {
          icon: icon.type === 'emoji' ? icon.icon : icon.fileId,
          icon_type: icon.type,
          icon_background: icon.type === 'emoji' ? icon.background : undefined,
          icon_url: icon.type === 'image' ? icon.url : undefined,
        },
      },
    }, {
      onSuccess: () => {
        toast.success(t('editDone'))
        setIsEditDialogOpen(false)
      },
      onError: (error) => {
        toast.error(error instanceof Error ? error.message : t('editFailed'))
      },
    })
  }, [snippet.id, t, updateSnippetMutation])

  // Delete the snippet and, on success, leave the detail view for the list.
  const handleDeleteSnippet = React.useCallback(() => {
    deleteSnippetMutation.mutate({
      params: { snippetId: snippet.id },
    }, {
      onSuccess: () => {
        toast.success(t('deleted'))
        setIsDeleteDialogOpen(false)
        replace('/snippets')
      },
      onError: (error) => {
        toast.error(error instanceof Error ? error.message : t('deleteFailed'))
      },
    })
  }, [deleteSnippetMutation, replace, snippet.id, t])

  return (
    <>
      <DropdownMenu open={open} onOpenChange={setOpen}>
        <DropdownMenuTrigger
          className={cn('action-btn action-btn-m size-6 rounded-md text-text-tertiary', open && 'bg-state-base-hover text-text-secondary')}
        >
          <span aria-hidden className="i-ri-more-fill size-4" />
        </DropdownMenuTrigger>
        <DropdownMenuContent
          placement="bottom-end"
          sideOffset={4}
          popupClassName="w-[180px] p-1"
        >
          <DropdownMenuItem className="mx-0 gap-2" onClick={handleOpenEditDialog}>
            <span aria-hidden className="i-ri-edit-line size-4 shrink-0 text-text-tertiary" />
            <span className="grow">{t('menu.editInfo')}</span>
          </DropdownMenuItem>
          <DropdownMenuItem className="mx-0 gap-2" onClick={handleExportSnippet}>
            <span aria-hidden className="i-ri-download-2-line size-4 shrink-0 text-text-tertiary" />
            <span className="grow">{t('menu.exportSnippet')}</span>
          </DropdownMenuItem>
          <DropdownMenuSeparator className="!my-1 bg-divider-subtle" />
          <DropdownMenuItem
            className="mx-0 gap-2"
            destructive
            onClick={() => {
              setOpen(false)
              setIsDeleteDialogOpen(true)
            }}
          >
            <span aria-hidden className="i-ri-delete-bin-line size-4 shrink-0" />
            <span className="grow">{t('menu.deleteSnippet')}</span>
          </DropdownMenuItem>
        </DropdownMenuContent>
      </DropdownMenu>
      {isEditDialogOpen && (
        <CreateSnippetDialog
          isOpen={isEditDialogOpen}
          initialValue={initialValue}
          title={t('editDialogTitle')}
          confirmText={t('operation.save', { ns: 'common' })}
          isSubmitting={updateSnippetMutation.isPending}
          onClose={() => setIsEditDialogOpen(false)}
          onConfirm={handleEditSnippet}
        />
      )}
      <AlertDialog open={isDeleteDialogOpen} onOpenChange={setIsDeleteDialogOpen}>
        <AlertDialogContent className="w-[400px]">
          <div className="space-y-2 p-6">
            <AlertDialogTitle className="text-text-primary title-lg-semi-bold">
              {t('deleteConfirmTitle')}
            </AlertDialogTitle>
            <AlertDialogDescription className="text-text-tertiary system-sm-regular">
              {t('deleteConfirmContent')}
            </AlertDialogDescription>
          </div>
          <AlertDialogActions className="pt-0">
            <AlertDialogCancelButton>
              {t('operation.cancel', { ns: 'common' })}
            </AlertDialogCancelButton>
            <AlertDialogConfirmButton
              loading={deleteSnippetMutation.isPending}
              onClick={handleDeleteSnippet}
            >
              {t('menu.deleteSnippet')}
            </AlertDialogConfirmButton>
          </AlertDialogActions>
        </AlertDialogContent>
      </AlertDialog>
    </>
  )
}

export default React.memo(SnippetInfoDropdown)

View File

@@ -0,0 +1,55 @@
'use client'
import type { SnippetDetail } from '@/models/snippet'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import { cn } from '@/utils/classnames'
import SnippetInfoDropdown from './dropdown'

type SnippetInfoProps = {
  expand: boolean
  snippet: SnippetDetail
}

/**
 * Sidebar header for a snippet. Collapsed mode shows only the icon;
 * expanded mode adds the name, type label, description and the action menu.
 */
const SnippetInfo = ({
  expand,
  snippet,
}: SnippetInfoProps) => {
  const { t } = useTranslation('snippet')

  // Layout classes vary only on the expand flag; hoist them for readability.
  const outerClassName = cn('flex flex-col', expand ? 'px-2 pb-1 pt-2' : 'p-1')
  const cardClassName = cn('flex flex-col', expand ? 'gap-2 rounded-xl p-2' : '')
  const headerClassName = cn('flex', expand ? 'items-center justify-between' : 'items-start gap-3')

  const iconBlock = (
    <div className={cn('shrink-0', !expand && 'ml-1')}>
      <AppIcon
        size={expand ? 'large' : 'small'}
        iconType="emoji"
        icon={snippet.icon}
        background={snippet.iconBackground}
      />
    </div>
  )

  return (
    <div className={outerClassName}>
      <div className={cardClassName}>
        <div className={headerClassName}>
          {iconBlock}
          {expand && <SnippetInfoDropdown snippet={snippet} />}
        </div>
        {expand && (
          <div className="min-w-0">
            <div className="truncate text-text-secondary system-md-semibold">
              {snippet.name}
            </div>
            <div className="pt-1 text-text-tertiary system-2xs-medium-uppercase">
              {t('typeLabel')}
            </div>
          </div>
        )}
        {expand && snippet.description && (
          <p className="line-clamp-3 break-words text-text-tertiary system-xs-regular">
            {snippet.description}
          </p>
        )}
      </div>
    </div>
  )
}

export default React.memo(SnippetInfo)

View File

@@ -0,0 +1,69 @@
// Unit tests for the app zustand store: default state, each setter, and the
// side effect that closing the message log modal resets the active tab.
import { useStore } from '../store'

// Restore every field these tests touch; zustand module state would otherwise
// leak between test cases.
const resetStore = () => {
  useStore.setState({
    appDetail: undefined,
    appSidebarExpand: '',
    currentLogItem: undefined,
    currentLogModalActiveTab: 'DETAIL',
    showPromptLogModal: false,
    showAgentLogModal: false,
    showMessageLogModal: false,
    showAppConfigureFeaturesModal: false,
  })
}

describe('app store', () => {
  beforeEach(() => {
    resetStore()
  })

  it('should expose the default state', () => {
    expect(useStore.getState()).toEqual(expect.objectContaining({
      appDetail: undefined,
      appSidebarExpand: '',
      currentLogItem: undefined,
      currentLogModalActiveTab: 'DETAIL',
      showPromptLogModal: false,
      showAgentLogModal: false,
      showMessageLogModal: false,
      showAppConfigureFeaturesModal: false,
    }))
  })

  it('should update every mutable field through its actions', () => {
    // Minimal stand-ins cast to the store's own field types.
    const appDetail = { id: 'app-1' } as ReturnType<typeof useStore.getState>['appDetail']
    const currentLogItem = { id: 'message-1' } as ReturnType<typeof useStore.getState>['currentLogItem']
    useStore.getState().setAppDetail(appDetail)
    useStore.getState().setAppSidebarExpand('logs')
    useStore.getState().setCurrentLogItem(currentLogItem)
    useStore.getState().setCurrentLogModalActiveTab('MESSAGE')
    useStore.getState().setShowPromptLogModal(true)
    useStore.getState().setShowAgentLogModal(true)
    useStore.getState().setShowAppConfigureFeaturesModal(true)
    expect(useStore.getState()).toEqual(expect.objectContaining({
      appDetail,
      appSidebarExpand: 'logs',
      currentLogItem,
      currentLogModalActiveTab: 'MESSAGE',
      showPromptLogModal: true,
      showAgentLogModal: true,
      showAppConfigureFeaturesModal: true,
    }))
  })

  it('should reset the active tab when the message log modal closes', () => {
    useStore.getState().setCurrentLogModalActiveTab('TRACE')
    useStore.getState().setShowMessageLogModal(true)
    expect(useStore.getState().showMessageLogModal).toBe(true)
    expect(useStore.getState().currentLogModalActiveTab).toBe('TRACE')
    // Closing the modal should snap the tab back to the default DETAIL view.
    useStore.getState().setShowMessageLogModal(false)
    expect(useStore.getState().showMessageLogModal).toBe(false)
    expect(useStore.getState().currentLogModalActiveTab).toBe('DETAIL')
  })
})

View File

@@ -1,6 +1,6 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import BatchAction from './batch-action'
import BatchAction from '../batch-action'
describe('BatchAction', () => {
const baseProps = {

View File

@@ -1,6 +1,6 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import EmptyElement from './empty-element'
import EmptyElement from '../empty-element'
describe('EmptyElement', () => {
it('should render the empty state copy and supporting icon', () => {

View File

@@ -1,12 +1,12 @@
import type { UseQueryResult } from '@tanstack/react-query'
import type { Mock } from 'vitest'
import type { QueryParam } from './filter'
import type { QueryParam } from '../filter'
import type { AnnotationsCountResponse } from '@/models/log'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import * as useLogModule from '@/service/use-log'
import Filter from './filter'
import Filter from '../filter'
vi.mock('@/service/use-log')

View File

@@ -1,5 +1,6 @@
/* eslint-disable ts/no-explicit-any */
import type { Mock } from 'vitest'
import type { AnnotationItem } from './type'
import type { AnnotationItem } from '../type'
import type { App } from '@/types/app'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
@@ -9,13 +10,16 @@ import {
addAnnotation,
delAnnotation,
delAnnotations,
editAnnotation,
fetchAnnotationConfig,
fetchAnnotationList,
queryAnnotationJobStatus,
updateAnnotationScore,
updateAnnotationStatus,
} from '@/service/annotation'
import { AppModeEnum } from '@/types/app'
import Annotation from './index'
import { JobStatus } from './type'
import Annotation from '../index'
import { AnnotationEnableStatus, JobStatus } from '../type'
vi.mock('ahooks', () => ({
useDebounce: (value: any) => value,
@@ -37,29 +41,32 @@ vi.mock('@/context/provider-context', () => ({
useProviderContext: vi.fn(),
}))
vi.mock('./filter', () => ({
vi.mock('../filter', () => ({
default: ({ children }: { children: React.ReactNode }) => (
<div data-testid="filter">{children}</div>
),
}))
vi.mock('./empty-element', () => ({
vi.mock('../empty-element', () => ({
default: () => <div data-testid="empty-element" />,
}))
vi.mock('./header-opts', () => ({
vi.mock('../header-opts', () => ({
default: (props: any) => (
<div data-testid="header-opts">
<button data-testid="trigger-add" onClick={() => props.onAdd({ question: 'new question', answer: 'new answer' })}>
add
</button>
<button data-testid="trigger-added" onClick={() => props.onAdded()}>
added
</button>
</div>
),
}))
let latestListProps: any
vi.mock('./list', () => ({
vi.mock('../list', () => ({
default: (props: any) => {
latestListProps = props
if (!props.list.length)
@@ -74,7 +81,7 @@ vi.mock('./list', () => ({
},
}))
vi.mock('./view-annotation-modal', () => ({
vi.mock('../view-annotation-modal', () => ({
default: (props: any) => {
if (!props.isShow)
return null
@@ -82,14 +89,40 @@ vi.mock('./view-annotation-modal', () => ({
<div data-testid="view-modal">
<div>{props.item.question}</div>
<button data-testid="view-modal-remove" onClick={props.onRemove}>remove</button>
<button data-testid="view-modal-save" onClick={() => props.onSave('Edited question', 'Edited answer')}>save</button>
<button data-testid="view-modal-close" onClick={props.onHide}>close</button>
</div>
)
},
}))
vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => ({ default: (props: any) => props.isShow ? <div data-testid="config-modal" /> : null }))
vi.mock('@/app/components/billing/annotation-full/modal', () => ({ default: (props: any) => props.show ? <div data-testid="annotation-full-modal" /> : null }))
vi.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => ({
default: (props: any) => props.isShow
? (
<div data-testid="config-modal">
<button
data-testid="config-save"
onClick={() => props.onSave({
embedding_model_name: 'next-model',
embedding_provider_name: 'next-provider',
}, 0.7)}
>
save-config
</button>
<button data-testid="config-hide" onClick={props.onHide}>hide-config</button>
</div>
)
: null,
}))
vi.mock('@/app/components/billing/annotation-full/modal', () => ({
default: (props: any) => props.show
? (
<div data-testid="annotation-full-modal">
<button data-testid="hide-annotation-full-modal" onClick={props.onHide}>hide-full</button>
</div>
)
: null,
}))
const mockNotify = vi.fn()
vi.spyOn(toast, 'success').mockImplementation((message, options) => {
@@ -111,9 +144,12 @@ vi.spyOn(toast, 'info').mockImplementation((message, options) => {
const addAnnotationMock = addAnnotation as Mock
const delAnnotationMock = delAnnotation as Mock
const delAnnotationsMock = delAnnotations as Mock
const editAnnotationMock = editAnnotation as Mock
const fetchAnnotationConfigMock = fetchAnnotationConfig as Mock
const fetchAnnotationListMock = fetchAnnotationList as Mock
const queryAnnotationJobStatusMock = queryAnnotationJobStatus as Mock
const updateAnnotationScoreMock = updateAnnotationScore as Mock
const updateAnnotationStatusMock = updateAnnotationStatus as Mock
const useProviderContextMock = useProviderContext as Mock
const appDetail = {
@@ -146,6 +182,9 @@ describe('Annotation', () => {
})
fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 })
queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed })
updateAnnotationStatusMock.mockResolvedValue({ job_id: 'job-1' })
updateAnnotationScoreMock.mockResolvedValue(undefined)
editAnnotationMock.mockResolvedValue(undefined)
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 0 },
@@ -251,4 +290,166 @@ describe('Annotation', () => {
expect(latestListProps.selectedIds).toEqual([annotation.id])
})
})
it('should show the annotation-full modal when enabling annotations exceeds the plan quota', async () => {
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 10 },
total: { annotatedResponse: 10 },
},
enableBilling: true,
})
renderComponent()
const toggle = await screen.findByRole('switch')
fireEvent.click(toggle)
expect(screen.getByTestId('annotation-full-modal')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('hide-annotation-full-modal'))
expect(screen.queryByTestId('annotation-full-modal')).not.toBeInTheDocument()
})
it('should disable annotations and refetch config after the async job completes', async () => {
fetchAnnotationConfigMock.mockResolvedValueOnce({
id: 'config-id',
enabled: true,
embedding_model: {
embedding_model_name: 'model',
embedding_provider_name: 'provider',
},
score_threshold: 0.5,
}).mockResolvedValueOnce({
id: 'config-id',
enabled: false,
embedding_model: {
embedding_model_name: 'model',
embedding_provider_name: 'provider',
},
score_threshold: 0.5,
})
renderComponent()
const toggle = await screen.findByRole('switch')
await waitFor(() => {
expect(toggle).toHaveAttribute('aria-checked', 'true')
})
fireEvent.click(toggle)
await waitFor(() => {
expect(updateAnnotationStatusMock).toHaveBeenCalledWith(
appDetail.id,
AnnotationEnableStatus.disable,
expect.objectContaining({
embedding_model_name: 'model',
embedding_provider_name: 'provider',
}),
0.5,
)
expect(queryAnnotationJobStatusMock).toHaveBeenCalledWith(appDetail.id, AnnotationEnableStatus.disable, 'job-1')
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
message: 'common.api.actionSuccess',
type: 'success',
}))
})
})
it('should save annotation config changes and update the score when the modal confirms', async () => {
fetchAnnotationConfigMock.mockResolvedValue({
id: 'config-id',
enabled: false,
embedding_model: {
embedding_model_name: 'model',
embedding_provider_name: 'provider',
},
score_threshold: 0.5,
})
renderComponent()
const toggle = await screen.findByRole('switch')
fireEvent.click(toggle)
expect(screen.getByTestId('config-modal')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('config-save'))
await waitFor(() => {
expect(updateAnnotationStatusMock).toHaveBeenCalledWith(
appDetail.id,
AnnotationEnableStatus.enable,
{
embedding_model_name: 'next-model',
embedding_provider_name: 'next-provider',
},
0.7,
)
expect(updateAnnotationScoreMock).toHaveBeenCalledWith(appDetail.id, 'config-id', 0.7)
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
message: 'common.api.actionSuccess',
type: 'success',
}))
})
})
it('should refresh the list from the header shortcut and allow saving or closing the view modal', async () => {
const annotation = createAnnotation()
fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
renderComponent()
await screen.findByTestId('list')
fireEvent.click(screen.getByTestId('list-view'))
fireEvent.click(screen.getByTestId('view-modal-save'))
await waitFor(() => {
expect(editAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id, {
question: 'Edited question',
answer: 'Edited answer',
})
})
fireEvent.click(screen.getByTestId('view-modal-close'))
expect(screen.queryByTestId('view-modal')).not.toBeInTheDocument()
fireEvent.click(screen.getByTestId('trigger-added'))
expect(fetchAnnotationListMock).toHaveBeenCalled()
})
it('should clear selections on cancel and hide the config modal when requested', async () => {
const annotation = createAnnotation()
fetchAnnotationConfigMock.mockResolvedValue({
id: 'config-id',
enabled: true,
embedding_model: {
embedding_model_name: 'model',
embedding_provider_name: 'provider',
},
score_threshold: 0.5,
})
fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
renderComponent()
await screen.findByTestId('list')
await act(async () => {
latestListProps.onSelectedIdsChange([annotation.id])
})
await act(async () => {
latestListProps.onCancel()
})
expect(latestListProps.selectedIds).toEqual([])
const configButton = document.querySelector('.action-btn') as HTMLButtonElement
fireEvent.click(configButton)
expect(await screen.findByTestId('config-modal')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('config-hide'))
expect(screen.queryByTestId('config-modal')).not.toBeInTheDocument()
})
})

View File

@@ -1,7 +1,7 @@
import type { AnnotationItem } from './type'
import type { AnnotationItem } from '../type'
import { fireEvent, render, screen, within } from '@testing-library/react'
import * as React from 'react'
import List from './list'
import List from '../list'
const mockFormatTime = vi.fn(() => 'formatted-time')

View File

@@ -2,7 +2,7 @@ import type { Mock } from 'vitest'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useProviderContext } from '@/context/provider-context'
import AddAnnotationModal from './index'
import AddAnnotationModal from '../index'
vi.mock('@/context/provider-context', () => ({
useProviderContext: vi.fn(),

View File

@@ -1,6 +1,6 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import EditItem, { EditItemType } from './index'
import EditItem, { EditItemType } from '../index'
describe('AddAnnotationModal/EditItem', () => {
it('should render query inputs with user avatar and placeholder strings', () => {

View File

@@ -1,10 +1,11 @@
/* eslint-disable ts/no-explicit-any */
import type { Mock } from 'vitest'
import type { Locale } from '@/i18n-config'
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import { useLocale } from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import CSVDownload from './csv-downloader'
import CSVDownload from '../csv-downloader'
const downloaderProps: any[] = []

View File

@@ -1,7 +1,7 @@
import type { Props } from './csv-uploader'
import type { Props } from '../csv-uploader'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import CSVUploader from './csv-uploader'
import CSVUploader from '../csv-uploader'
const toastMocks = vi.hoisted(() => ({
notify: vi.fn(),
@@ -75,6 +75,20 @@ describe('CSVUploader', () => {
expect(dropZone.className).not.toContain('border-components-dropzone-border-accent')
})
it('should handle drag over and clear dragging state when leaving through the overlay', () => {
renderComponent()
const { dropZone, dropContainer } = getDropElements()
fireEvent.dragEnter(dropContainer)
const dragLayer = dropContainer.querySelector('.absolute') as HTMLDivElement
fireEvent.dragOver(dropContainer)
fireEvent.dragLeave(dragLayer)
expect(dropZone.className).not.toContain('border-components-dropzone-border-accent')
expect(dropZone.className).not.toContain('bg-components-dropzone-bg-accent')
})
it('should ignore drop events without dataTransfer', () => {
renderComponent()
const { dropContainer } = getDropElements()

View File

@@ -1,10 +1,10 @@
import type { Mock } from 'vitest'
import type { IBatchModalProps } from './index'
import type { IBatchModalProps } from '../index'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useProviderContext } from '@/context/provider-context'
import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation'
import BatchModal, { ProcessStatus } from './index'
import BatchModal, { ProcessStatus } from '../index'
vi.mock('@/service/annotation', () => ({
annotationBatchImport: vi.fn(),
@@ -15,13 +15,13 @@ vi.mock('@/context/provider-context', () => ({
useProviderContext: vi.fn(),
}))
vi.mock('./csv-downloader', () => ({
vi.mock('../csv-downloader', () => ({
default: () => <div data-testid="csv-downloader-stub" />,
}))
let lastUploadedFile: File | undefined
vi.mock('./csv-uploader', () => ({
vi.mock('../csv-uploader', () => ({
default: ({ file, updateFile }: { file?: File, updateFile: (file?: File) => void }) => (
<div>
<button

View File

@@ -1,6 +1,6 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import ClearAllAnnotationsConfirmModal from './index'
import ClearAllAnnotationsConfirmModal from '../index'
vi.mock('react-i18next', () => ({
useTranslation: () => ({

View File

@@ -1,7 +1,7 @@
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { toast } from '@/app/components/base/ui/toast'
import EditAnnotationModal from './index'
import EditAnnotationModal from '../index'
const { mockAddAnnotation, mockEditAnnotation } = vi.hoisted(() => ({
mockAddAnnotation: vi.fn(),
@@ -51,11 +51,6 @@ describe('EditAnnotationModal', () => {
onRemove: vi.fn(),
}
afterAll(() => {
toastSuccessSpy.mockRestore()
toastErrorSpy.mockRestore()
})
beforeEach(() => {
vi.clearAllMocks()
mockAddAnnotation.mockResolvedValue({
@@ -65,6 +60,11 @@ describe('EditAnnotationModal', () => {
mockEditAnnotation.mockResolvedValue({})
})
afterAll(() => {
toastSuccessSpy.mockRestore()
toastErrorSpy.mockRestore()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render modal when isShow is true', () => {

Some files were not shown because too many files have changed in this diff Show More