Compare commits

..

114 Commits

Author SHA1 Message Date
longbingljw
9f586c44d8 fix:env_file should read .env (#67)
* config adapt revert

* ci test

* fix mysql migration test

* fix

* fix

* lint fix

* fix ob config

* fix

* fix

* fix

* test over

* test

* fix

* fix

* fix style

* test over

* retain gin for pg

* gin for pg

* uuid defalut in versions

* ci test

* ci test

* fix

* fix

* fix

* fix

* pg josnb

* fix

* fix

* add seekdb

* test over

* test over

* fix env_fix
2025-11-17 10:52:59 +08:00
longbingljw
5322f3bbd4 feat:add seekdb (#66)
* config adapt revert

* ci test

* fix mysql migration test

* fix

* fix

* lint fix

* fix ob config

* fix

* fix

* fix

* test over

* test

* fix

* fix

* fix style

* test over

* retain gin for pg

* gin for pg

* uuid defalut in versions

* ci test

* ci test

* fix

* fix

* fix

* fix

* pg josnb

* fix

* fix

* add seekdb

* test over

* test over
2025-11-16 18:39:35 +08:00
longbingljw
6433ac8209 feat:json metadat filter adapt (#65)
* config adapt revert

* ci test

* fix mysql migration test

* fix

* fix

* lint fix

* fix ob config

* fix

* fix

* fix

* test over

* test

* fix

* fix

* fix style

* test over

* retain gin for pg

* gin for pg

* uuid defalut in versions

* ci test

* ci test

* fix

* fix

* fix

* fix

* pg josnb

* fix
2025-11-15 22:29:59 +08:00
longbingljw
84935b9169 revert:public schema (#64)
* config adapt revert

* ci test

* fix mysql migration test

* fix

* fix

* lint fix

* fix ob config

* fix

* fix

* fix

* test over
2025-11-15 04:48:09 +08:00
longbingljw
eceaea68b1 fix:ci (#63)
* fix ci

* fix ci

* fix
2025-11-14 14:44:13 +08:00
longbingljw
26cfccb84b fix (#62) 2025-11-14 11:20:05 +08:00
longbingljw
3042b69e77 fix:style and ci (#60)
* mysql adaptation

* fix ci

* fix
2025-11-13 22:33:05 +08:00
longbingljw
49d5637b3c mysql adaptation (#59) 2025-11-13 21:27:41 +08:00
yangzheli
20403c69b2 refactor(web): remove redundant add-tool-modal components and related code (#27996) 2025-11-13 20:21:04 +08:00
hoffer
ffc04f2a9b fix: StreamableHTTPTransport got invalid json exception when receive a ping event from mcp server #28111 (#28116) 2025-11-13 20:19:48 +08:00
Asuka Minato
d1580791e4 TypedBase + TypedDict (#28137)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-13 20:18:51 +08:00
NeatGuyCoding
c74eb4fcf3 minor fix(rag): return early when pushing empty tasks to avoid Redis DataError (#28027)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-11-13 20:18:11 +08:00
NeatGuyCoding
a798534337 fix(web): fix unit promotion in formatNumberAbbreviated (#27918)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-11-13 20:17:26 +08:00
GuanMu
470883858e fix: adjust padding in AgentNode and NodeComponent for consistent layout (#28175) 2025-11-13 20:16:56 +08:00
GuanMu
4f4911686d fix: update start-worker alias to include additional queues for bette… (#28179) 2025-11-13 20:16:44 +08:00
GuanMu
6d479dcdbb fix: update package manager version to 10.22.0 (#28181) 2025-11-13 20:16:00 +08:00
zhsama
24348c40a6 feat: enhance start node metadata to be undeletable in chat mode (#28173) 2025-11-13 18:11:15 +08:00
yihong
a39b50adbb fix: skip tests if no database run (#28102)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-13 15:57:13 +08:00
李龙飞
81832c14ee Fix: Correctly handle merged cells in DOCX tables to prevent content duplication and loss (#27871)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-13 15:56:24 +08:00
zhsama
b86022c64a feat: add draft trigger detection to app model and UI (#28163)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-13 15:43:58 +08:00
breath57
45e816a9f6 fix(knowledge-base): regenerate child chunks not working completely (#27934) 2025-11-13 15:36:27 +08:00
Joel
667b1c37a3 fix: can still invite when api is pending (#28161) 2025-11-13 15:28:32 +08:00
Chen Yu
b75d533f9b fix(moderation): change OpenAI moderation model to omni-moderation-la… (#28119)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-13 15:21:44 +08:00
CrabSAMA
aece55d82f fix: fixed error when clear value of INTEGER and FLOAT type (#27954)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-13 15:21:34 +08:00
kenwoodjw
c432b398f4 fix: missing pipeline_templates.json when HOSTED_FETCH_PIPELINE_TEMPLATES_MODE is builtin (#27946)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-13 15:04:35 +08:00
katakyo
9cb2645793 fix: update input field width for retry configuration in RetryOnPanel (#28142) 2025-11-13 15:00:22 +08:00
ye4241
6ac61bd585 fix: correct spelling of "模板" in translation files (#28151) 2025-11-13 14:58:10 +08:00
非法操作
b02165ffe6 fix: inconsistent behaviour of zoom in button and shortcut (#27944) 2025-11-13 14:37:27 +08:00
Asuka Minato
6c576e2c66 add doc (#28016)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-13 13:38:45 +09:00
yangzheli
b0e7e7752f refactor(web): reuse the same edit-custom-collection-modal component, and fix the pop up error (#28003) 2025-11-13 11:44:21 +08:00
mnasrautinno
2799b79e8c fix: app's ai site text to speech api (#28091) 2025-11-13 11:44:04 +08:00
Maries
805a1479f9 fix: simplify graph structure validation in WorkflowService (#28146)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-13 10:59:31 +08:00
-LAN-
fe6538b08d chore: disable workflow logs auto-cleanup by default (#28136)
This PR changes the default value of `WORKFLOW_LOG_CLEANUP_ENABLED` from `true` to `false` across all configuration files.

## Motivation

Setting the default to `false` provides safer default behavior by:

- Preventing unintended data loss for new installations
- Giving users explicit control over when to enable log cleanup
- Following the opt-in principle for data deletion features

Users who need automatic cleanup can enable it by setting `WORKFLOW_LOG_CLEANUP_ENABLED=true` in their configuration.
2025-11-12 22:55:02 +08:00
Asuka Minato
1bbb9d6644 convert to TypeBase (#27935)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-12 21:50:13 +08:00
Gritty_dev
5c06e285ec test: create some hooks and utils test script, modified clipboard test script (#27928) 2025-11-12 21:47:06 +08:00
Gen Sato
19c92fd670 Add file type validation to paste upload (#28017) 2025-11-12 19:27:56 +08:00
非法操作
6026bd873b fix: variable assigner can't assign float number (#28068) 2025-11-12 19:27:36 +08:00
Bowen Liang
1369119a0c fix: determine cpu cores determination in baseedpyright-check script on macos (#28058) 2025-11-12 19:27:27 +08:00
Yeuoly
b76e17b25d feat: introduce trigger functionality (#27644)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
Co-authored-by: Harry <xh001x@hotmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-12 17:59:37 +08:00
Jyong
ca7794305b add transform-datasource-credentials command online check (#28124)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
2025-11-12 17:13:44 +08:00
QuantumGhost
fd255e81e1 feat(api): Introduce WorkflowResumptionContext for pause state management (#28122)
Certain metadata (including but not limited to `InvokeFrom`, `call_depth`, and `streaming`)  is required when resuming a paused workflow. However, these fields are not part of `GraphRuntimeState` and were not saved in the previous
 implementation of  `PauseStatePersistenceLayer`.

This commit addresses this limitation by introducing a `WorkflowResumptionContext` model that wraps both the `*GenerateEntity` and `GraphRuntimeState`. This approach provides:

- A structured container for all necessary resumption data
- Better separation of concerns between execution state and persistence
- Enhanced extensibility for future metadata additions
- Clearer naming that distinguishes from `GraphRuntimeState`

The `WorkflowResumptionContext` model makes extending the pause state easier while maintaining backward compatibility and proper version management for the entire execution state ecosystem.

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-12 17:00:02 +08:00
Joel
09d31d1263 chore: improve the user experience of not login into apps (#28120) 2025-11-12 16:47:45 +08:00
Jyong
47dc26f011 fix document index test (#28113) 2025-11-12 16:00:10 +08:00
湛露先生
123bb3ec08 When graph_engine worker run exception, keep the node_id for deep res… (#26205)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-11-12 15:03:45 +08:00
Joel
90f77282e3 chore: not SaaS version can query long log time range (#28109) 2025-11-12 14:45:56 +08:00
Jyong
5208867ccc fix document enable (#28081) 2025-11-11 17:50:45 +08:00
lyzno1
edc7ccc795 chore: add type-check to pre-commit (#28005) 2025-11-11 16:14:39 +08:00
Ali Saleh
c9798f6425 fix(api): Trace Hierarchy, Span Status, and Broken Workflow for Arize & Phoenix Integration (#27937)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-11 11:49:19 +08:00
crazywoola
20ecf7f1d0 chore: remove unused enterprise bot from the readme (#28073)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-11 10:52:27 +08:00
github-actions[bot]
9dcb780fcb chore: translate i18n files and update type definitions (#28054)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
2025-11-11 09:32:53 +08:00
Will
1cb7b09933 chore: Remove trailing space from migration filename (#28040) 2025-11-11 09:32:42 +08:00
Joel
2c62a77cf4 Chore: change query log time range (#28052)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-10 18:39:12 +08:00
QuantumGhost
b9bc48d8dd feat(api): Introduce Broadcast Channel (#27835)
This PR introduces a `BroadcastChannel` abstraction with broadcasting and at-most once delivery semantics, serving as the communication component between celery worker and API server.

It also includes a reference implementation backed by Redis PubSub.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-10 17:23:21 +08:00
Will
ed234e311b fix workflow default updated_at (#28047) 2025-11-10 18:20:38 +09:00
huangzhuo1949
9843fec393 fix: elasticsearch_vector version (#28028)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-10 13:17:13 +09:00
Will
aa4cabdeb5 feat: Add Audio Content Support for MCP Tools (#27979) 2025-11-10 10:12:11 +08:00
NeatGuyCoding
eea713b668 Fix typo in weaviate comment, improve time test precision, and add security tests for get-icon utility (#27919)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-10 10:11:54 +08:00
dependabot[bot]
fc62538a94 chore(deps): bump scipy-stubs from 1.16.2.3 to 1.16.3.0 in /api (#28025)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-10 09:54:56 +08:00
Asuka Minato
7994144df7 add onupdate=func.current_timestamp() (#28014)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-10 01:48:52 +09:00
Kenn
e153c483b6 fix: the model list encountered two children with the same key (#27956)
Co-authored-by: haokai <haokai@shuwen.com>
2025-11-09 21:39:59 +08:00
wangxiaolei
422bb4d4bb fix: fix https://github.com/langgenius/dify/issues/27939 (#27985) 2025-11-09 21:39:05 +08:00
OneZero-Y
87a80d7613 docs: clarify how to obtain workflow_id for version execution (#28007)
Signed-off-by: OneZero-Y <aukovyps@163.com>
2025-11-09 21:38:06 +08:00
kenwoodjw
e91105ca87 fix: bump brotli to 1.2.0 resloved CVE-2025-6176 (#27950)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-11-07 15:57:29 +08:00
hj24
37903722fe refactor: implement tenant self queue for rag tasks (#27559)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-11-06 21:25:50 +08:00
QuantumGhost
f4c82d0010 fix(api): fix VariablePool.get adding unexpected keys to variable_dictionary (#26767)
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-06 18:30:35 +08:00
NFish
fe50093c18 fix: prevent fetch version info in enterprise edition (#27923) 2025-11-06 17:59:53 +08:00
Jyong
4317af1e90 fix jina reader transform (#27922) 2025-11-06 17:35:53 +08:00
red_sun
61a0fcc2ea fix agent putout the output of workflow-tool twice (#26835) (#27087) 2025-11-06 09:41:05 +08:00
Jyong
f627348b11 fix jina reader creadential migration command (#27883) 2025-11-05 18:42:07 +08:00
Cursx
87fb9a6b69 fix Version 2.0.0-beta.2: Chat annotations Api Error #25506 (#27206)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2025-11-05 17:37:19 +08:00
Yongtao Huang
97a2e2ec2e Fix: correct DraftWorkflowApi.post response model (#27289)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-05 17:20:40 +08:00
Boris Polonsky
68d357d7f6 Add WEAVIATE_GRPC_ENDPOINT as designed in weaviate migration guide (#27861)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-05 17:19:08 +08:00
crazywoola
a103ad3ee7 bump vite to 6.4.1 (#27877) 2025-11-05 16:33:19 +08:00
wangjifeng
f65d5a9761 Fix/template transformer line number (#27867)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-05 15:21:47 +08:00
github-actions[bot]
6e0a5f5bbd chore: translate i18n files and update type definitions (#27868)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-05 15:17:53 +08:00
crazywoola
22f858152f feat: change feedback to forum (#27862) 2025-11-05 14:51:57 +08:00
Gritty_dev
775d2e14fc test: create new test scripts and update some existing test scripts o… (#27850) 2025-11-05 11:09:24 +08:00
johnny0120
744b287e67 fix: avoid passing empty uniqueIdentifier to InstallFromMarketplace (#27802)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-05 10:22:22 +08:00
crazywoola
c0fc5d98f0 fix: installation_id is missing when in tools page (#27849) 2025-11-05 10:19:12 +08:00
Elliott
08ea79d730 fix(web): increase z-index of PortalToFollowElemContent (#27823) 2025-11-05 09:32:15 +08:00
yangzheli
f31b821cc0 fix(web): improve the consistency of the inputs-form UI (#27837) 2025-11-05 09:29:13 +08:00
Novice
34be16874f feat: add validation to prevent saving empty opening statement in conversation opener modal (#27843) 2025-11-05 09:28:49 +08:00
aka James4u
e9738b891f test: adding some web tests (#27792) 2025-11-04 21:06:44 +08:00
zhengchangchun
829796514a fix:knowledge base reference information is overwritten when using mu… (#27799)
Co-authored-by: zhengchangchun <zhengchangchun@corp.netease.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-04 16:40:44 +08:00
Novice
ef1db35f80 feat: implement file extension blacklist for upload security (#27540) 2025-11-04 15:45:22 +08:00
Cursx
f9c67621ca fix agent putout the output of workflow-tool twice (#26835) (#27706) 2025-11-04 14:24:51 +08:00
Guangdong Liu
e29e8e3180 feat: enhance annotation API to support optional message_id and content fields (#27460)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-04 14:11:09 +08:00
red_sun
7a81e720d4 fix: iteration node cannot be viewed(#27759) (#27786) 2025-11-04 12:37:31 +08:00
XlKsyt
55600c0eb1 feat: add metrics logging and improve MeterProvider lifecycle for tencent APM (#27733)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-04 12:35:53 +08:00
kenwoodjw
35e41d7d68 fix: bump pyobvector to 0.2.17 (#27791)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-11-04 12:25:50 +08:00
Ponder
b610cf9a11 feat: add segments max number limit for SegmentApi.post (#27745) 2025-11-04 10:27:58 +08:00
-LAN-
c8e9edc024 refactor(api): set default value for EasyUIBasedAppGenerateEntity.query (#27712) 2025-11-04 10:22:43 +08:00
49
471cd760d7 fix: improve infinite scroll observer responsiveness (#27546)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-04 10:15:27 +08:00
墨绿色
7f48c57edf fix: datasets weight settings embedding model does not change (#27694)
Co-authored-by: lijiezhao <lijiezhao@perfect99.com>
2025-11-04 10:00:36 +08:00
NeatGuyCoding
6569801162 extract parse_time_range for console app stats related queries (#27626)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-11-04 10:00:12 +08:00
国昊
9dd83f50a7 FIX Issue #27697: Add env variable in docker-compose(template) and make it take effect. (#27704) 2025-11-04 09:58:59 +08:00
CrabSAMA
59c56b1b0d fix: File model add known extra fields, fix issue about the tool of… (#27607) 2025-11-04 09:57:25 +08:00
Tianzhi Jin
94cd2de940 fix(api): return timestamp as integer in document api (#27761) 2025-11-04 09:55:47 +08:00
heyszt
3c23375607 refactor: Use Repository Pattern for Model Layer (#27663) 2025-11-04 09:53:22 +08:00
dependabot[bot]
56047f638f chore(deps): bump dayjs from 1.11.18 to 1.11.19 in /web (#27735)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-04 09:47:57 +08:00
vicen
9c01d3e775 fix: two web bugs for json-schema-config-modal (#27718) 2025-11-04 09:45:28 +08:00
海狸大師
c85c87f3da fix(i18n/zh-Hant): unify terminology and improve translation consistency (#27717) 2025-11-04 09:42:26 +08:00
-LAN-
eaa02e3d55 Add SQLAlchemy Mapped annotations to MessageFeedback (#27768) 2025-11-04 09:39:59 +08:00
yihong
0219222a60 fix: pin litellm version ignore build issue (#27742)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-11-04 09:39:03 +08:00
yangzheli
dba659b220 fix(web): fix issues with links, Chinese translations, and styling on the logs page (#27669) 2025-11-04 09:38:15 +08:00
Bowen Liang
ee6458768e cleanup orphan packages in packages stage of api dockerfile (#27617) 2025-11-04 09:36:52 +08:00
Shemol
ed3d02dc6d web(markdown): support <think> without trailing newline in preprocessThinkTag (#27776)
Signed-off-by: SherlockShemol <shemol@163.com>
2025-11-04 09:35:54 +08:00
CrabSAMA
95471b1188 fix(ui): fixed the bug about empty placeholder when plugin install successfully (#27780) 2025-11-04 09:35:14 +08:00
aka James4u
6190cfbfd8 feat: localization for hi-IN (#27783)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-04 09:34:41 +08:00
aka James4u
11f2f95103 Added it-IT for italian (#27665) 2025-11-03 11:51:45 +08:00
-LAN-
2abbc14703 refactor: replace hardcoded user plan strings with CloudPlan enum (#27675) 2025-11-03 11:51:09 +08:00
dependabot[bot]
b2b2816ade chore(deps): bump tablestore from 6.2.0 to 6.3.7 in /api (#27736) 2025-11-03 11:50:39 +08:00
-LAN-
4461df1bd9 refactor(api): add SQLAlchemy 2.x Mapped type hints to Message model (#27709)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-01 01:16:07 +08:00
katakyo
f7f6b4a8b0 i18n(ja-JP): Use 「公開」 for App Overview “Launch” action label (#27680) 2025-10-31 11:23:38 +08:00
603 changed files with 30344 additions and 7453 deletions

View File

@@ -1,17 +1,15 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
source /home/vscode/.bashrc

View File

@@ -62,7 +62,7 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
db_postgres
redis
sandbox
ssrf_proxy

View File

@@ -8,7 +8,7 @@ concurrency:
cancel-in-progress: true
jobs:
db-migration-test:
db-migration-test-postgres:
runs-on: ubuntu-latest
steps:
@@ -45,7 +45,7 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
db_postgres
redis
- name: Prepare configs
@@ -57,3 +57,60 @@ jobs:
env:
DEBUG: true
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env for MySQL
run: |
cd docker
cp middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db_mysql
redis
- name: Prepare configs for MySQL
run: |
cd api
cp .env.example .env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
- name: Run DB Migration
env:
DEBUG: true
run: uv run --directory api flask upgrade-db

2
.gitignore vendored
View File

@@ -186,6 +186,8 @@ docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf

View File

@@ -117,7 +117,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.

View File

@@ -72,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# PostgreSQL database configuration
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30
@@ -164,7 +167,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -175,6 +178,17 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
OCEANBASE_FULLTEXT_PARSER=ik
SEEKDB_MEMORY_LIMIT=2G
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
@@ -340,15 +354,6 @@ LINDORM_PASSWORD=admin
LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
@@ -374,6 +379,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Comma-separated list of file extensions blocked from upload for security reasons.
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
# Empty by default to allow all file types.
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
@@ -521,7 +532,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true
WORKFLOW_LOG_CLEANUP_ENABLED=false
# Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
@@ -620,3 +631,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0

View File

@@ -15,7 +15,11 @@ FROM base AS packages
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
COPY pyproject.toml uv.lock ./
@@ -49,7 +53,9 @@ RUN \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2

View File

@@ -15,8 +15,8 @@
```bash
cd ../docker
cp middleware.env.example middleware.env
# change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
cd ../api
```

View File

@@ -1,7 +1,7 @@
import sys
def is_db_command():
def is_db_command() -> bool:
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False

View File

@@ -1471,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
@click.option(
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
"""
Transform datasource credentials
"""
@@ -1482,9 +1485,14 @@ def transform_datasource_credentials():
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
if environment == "online":
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
jina_plugin_unique_identifier = None
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
@@ -1650,7 +1658,7 @@ def transform_datasource_credentials():
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
provider="jinareader",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,

View File

@@ -360,6 +360,31 @@ class FileUploadConfig(BaseSettings):
default=10,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
"Empty by default to allow all file types."
),
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
default="",
)
@computed_field # type: ignore[misc]
@property
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
"""
Parse and return the blacklist as a set of lowercase extensions.
Returns an empty set if no blacklist is configured.
"""
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
return set()
return {
ext.strip().lower().strip(".")
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
if ext.strip()
}
class HttpConfig(BaseSettings):
"""
@@ -949,6 +974,11 @@ class DataSetConfig(BaseSettings):
default=True,
)
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
description="Maximum number of segments for dataset segments API (0 for unlimited)",
default=0,
)
class WorkspaceConfig(BaseSettings):
"""
@@ -1160,7 +1190,7 @@ class AccountConfig(BaseSettings):
class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
@@ -1179,6 +1209,13 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1205,6 +1242,7 @@ class FeatureConfig(
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,

View File

@@ -105,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
@@ -140,10 +146,10 @@ class DatabaseConfig(BaseSettings):
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="Database URI scheme for SQLAlchemy connection.",
default="postgresql",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
@computed_field # type: ignore[prop-decorator]
@property
@@ -204,15 +210,15 @@ class DatabaseConfig(BaseSettings):
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
connect_args = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,

View File

@@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,

File diff suppressed because one or more lines are too long

View File

@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class BlockedFileExtensionError(BaseHTTPException):
error_code = "file_extension_blocked"
description = "The file extension is blocked for security reasons."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."

View File

@@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect(
api.parser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@api.expect(parser)
@api.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@@ -25,13 +27,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args")
.add_argument("model_mode", type=str, required=True, location="args")
.add_argument("has_context", type=str, required=False, default="true", location="args")
.add_argument("model_name", type=str, required=True, location="args")
)
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)

View File

@@ -8,17 +8,19 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
)
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource):
@api.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
)
@api.expect(parser)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
@api.response(400, "Invalid request parameters")
@setup_required
@@ -27,12 +29,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args")
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
)
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])

View File

@@ -16,6 +16,7 @@ from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
@@ -175,8 +176,10 @@ class AnnotationApi(Resource):
api.model(
"CreateAnnotationRequest",
{
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
@@ -193,11 +196,14 @@ class AnnotationApi(Resource):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation
@setup_required
@@ -245,6 +251,13 @@ class AnnotationExportApi(Resource):
return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation")
@@ -253,6 +266,7 @@ class AnnotationUpdateDeleteApi(Resource):
@api.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@@ -262,11 +276,6 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation

View File

@@ -15,11 +15,12 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -106,6 +107,35 @@ class AppListApi(Resource):
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
workflow_capable_app_ids = [
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
]
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
)
)
.scalars()
.all()
)
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
return marshal(app_pagination, app_pagination_fields), 200
@api.doc("create_app")
@@ -353,12 +383,15 @@ class AppExportApi(Resource):
}
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@api.doc("check_app_name")
@api.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"})
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
@api.expect(parser)
@api.response(200, "Name availability checked")
@setup_required
@login_required
@@ -367,7 +400,6 @@ class AppNameApi(Resource):
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
app_service = AppService()

View File

@@ -1,6 +1,7 @@
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@@ -18,9 +19,23 @@ from services.feature_service import FeatureService
from .. import console_ns
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@@ -30,18 +45,6 @@ class AppImportApi(Resource):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session

View File

@@ -1,7 +1,5 @@
from datetime import datetime
import pytz
import sqlalchemy as sa
from flask import abort
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
@@ -19,7 +17,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
@@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)

View File

@@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
@@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
@@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@cloud_edition_billing_resource_check("annotation")
@account_initialization_required
@edit_permission_required
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count")

View File

@@ -1,9 +1,7 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@@ -11,9 +9,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString, convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
from models import AppMode
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
@@ -45,8 +44,9 @@ class DailyMessageStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(*) AS message_count
FROM
messages
@@ -56,26 +56,16 @@ WHERE
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -91,16 +81,19 @@ WHERE
return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily conversation statistics retrieved successfully",
@@ -113,48 +106,40 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
stmt = (
sa.select(
sa.func.date(
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
).label("date"),
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at >= start_datetime_utc)
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at < end_datetime_utc)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
stmt = stmt.group_by("date").order_by("date")
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(stmt, {"tz": account.timezone})
for row in rs:
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
return jsonify({"data": response_data})
@@ -164,11 +149,7 @@ class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily terminal statistics retrieved successfully",
@@ -181,15 +162,11 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM
messages
@@ -198,26 +175,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -238,11 +206,7 @@ class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily token cost statistics retrieved successfully",
@@ -255,15 +219,11 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price
FROM
@@ -273,26 +233,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -315,11 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Average session interaction statistics retrieved successfully",
@@ -332,15 +279,11 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(subquery.message_count) AS interactions
FROM
(
@@ -357,26 +300,17 @@ FROM
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -408,11 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"User satisfaction rate statistics retrieved successfully",
@@ -425,15 +355,11 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count
FROM
@@ -446,26 +372,17 @@ WHERE
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -491,11 +408,7 @@ class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Average response time statistics retrieved successfully",
@@ -508,15 +421,11 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(provider_response_latency) AS latency
FROM
messages
@@ -525,26 +434,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -565,11 +465,7 @@ class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Tokens per second statistics retrieved successfully",
@@ -581,16 +477,11 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
@@ -602,26 +493,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc

View File

@@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
@@ -112,7 +113,18 @@ class DraftWorkflowApi(Resource):
},
)
)
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
@@ -574,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
@api.doc("get_published_workflow")
@@ -598,6 +617,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
@api.expect(parser_publish)
@setup_required
@login_required
@account_initialization_required
@@ -608,12 +628,8 @@ class PublishedWorkflowApi(Resource):
Publish workflow
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
@@ -668,6 +684,9 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource):
@api.doc("get_default_block_config")
@@ -675,6 +694,7 @@ class DefaultBlockConfigApi(Resource):
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@api.response(200, "Default block configuration retrieved successfully")
@api.response(404, "Block type not found")
@api.expect(parser_block)
@setup_required
@login_required
@account_initialization_required
@@ -684,8 +704,7 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
args = parser_block.parse_args()
q = args.get("q")
@@ -701,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
@api.expect(parser_convert)
@api.doc("convert_to_workflow")
@api.doc(description="Convert application to workflow mode")
@api.doc(params={"app_id": "Application ID"})
@@ -723,14 +752,7 @@ class ConvertToWorkflowApi(Resource):
current_user, _ = current_account_with_tenant()
if request.data:
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_convert.parse_args()
else:
args = {}
@@ -744,8 +766,18 @@ class ConvertToWorkflowApi(Resource):
}
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.expect(parser_workflows)
@api.doc("get_all_published_workflows")
@api.doc(description="Get all published workflows for an application")
@api.doc(params={"app_id": "Application ID"})
@@ -762,16 +794,9 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
args = parser_workflows.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
@@ -979,11 +1004,13 @@ class DraftWorkflowTriggerRunApi(Resource):
event = poller.poll()
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=event.workflow_args,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
@@ -992,7 +1019,7 @@ class DraftWorkflowTriggerRunApi(Resource):
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except PluginInvokeError as e:
raise ValueError(e.to_user_friendly_error())
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@@ -1050,7 +1077,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
)
event = poller.poll()
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 500
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@@ -1074,7 +1101,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{"status": "error", "error": "An unexpected error occurred while running the node."}
), 500
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
@@ -1126,7 +1153,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
node_ids=node_ids,
)
except PluginInvokeError as e:
raise ValueError(e.to_user_friendly_error())
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@@ -1134,10 +1161,12 @@ class DraftWorkflowTriggerRunAllApi(Resource):
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
try:
workflow_args = dict(trigger_debug_event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=trigger_debug_event.workflow_args,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=trigger_debug_event.node_id,
@@ -1151,4 +1180,4 @@ class DraftWorkflowTriggerRunAllApi(Resource):
{
"status": "error",
}
), 500
), 400

View File

@@ -30,23 +30,25 @@ def _parse_workflow_run_list_args():
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()
@@ -58,28 +60,30 @@ def _parse_workflow_run_count_args():
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()

View File

@@ -1,7 +1,4 @@
from datetime import datetime
import pytz
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
@@ -9,6 +6,7 @@ from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
@@ -43,23 +41,11 @@ class WorkflowDailyRunsStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_daily_runs_statistics(
tenant_id=app_model.tenant_id,
@@ -100,23 +86,11 @@ class WorkflowDailyTerminalsStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
tenant_id=app_model.tenant_id,
@@ -157,23 +131,11 @@ class WorkflowDailyTokenCostStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
tenant_id=app_model.tenant_id,
@@ -214,23 +176,11 @@ class WorkflowAverageAppInteractionStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
tenant_id=app_model.tenant_id,

View File

@@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -16,7 +17,13 @@ class Subscription(Resource):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()

View File

@@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -121,8 +121,16 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
@api.expect(parser_datasource)
@setup_required
@login_required
@account_initialization_required
@@ -130,14 +138,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_datasource.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@@ -168,8 +169,14 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
@api.expect(parser_datasource_delete)
@setup_required
@login_required
@account_initialization_required
@@ -181,10 +188,7 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_datasource_delete.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
@@ -195,8 +199,17 @@ class DatasourceAuthDeleteApi(Resource):
return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
@api.expect(parser_datasource_update)
@setup_required
@login_required
@account_initialization_required
@@ -205,13 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_datasource_update.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
@@ -251,8 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
@api.expect(parser_datasource_custom)
@setup_required
@login_required
@account_initialization_required
@@ -260,12 +275,7 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_datasource_custom.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
@@ -291,8 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
@api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@@ -300,8 +314,7 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = parser_default.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
@@ -312,8 +325,16 @@ class DatasourceAuthDefaultApi(Resource):
return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
@api.expect(parser_update_name)
@setup_required
@login_required
@account_initialization_required
@@ -321,12 +342,7 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_update_name.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
@@ -148,8 +148,12 @@ class DraftRagPipelineApi(Resource):
}
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
@api.expect(parser_run)
@setup_required
@login_required
@account_initialization_required
@@ -162,8 +166,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = parser_run.parse_args()
try:
response = PipelineGenerateService.generate_single_iteration(
@@ -184,6 +187,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
@api.expect(parser_run)
@setup_required
@login_required
@account_initialization_required
@@ -197,8 +201,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = parser_run.parse_args()
try:
response = PipelineGenerateService.generate_single_loop(
@@ -217,8 +220,18 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
@api.expect(parser_draft_run)
@setup_required
@login_required
@account_initialization_required
@@ -232,14 +245,7 @@ class DraftRagPipelineRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_draft_run.parse_args()
try:
response = PipelineGenerateService.generate(
@@ -255,8 +261,21 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
@api.expect(parser_published_run)
@setup_required
@login_required
@account_initialization_required
@@ -270,17 +289,7 @@ class PublishedRagPipelineRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_published_run.parse_args()
streaming = args["response_mode"] == "streaming"
@@ -381,8 +390,17 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required
@login_required
@account_initialization_required
@@ -396,13 +414,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
@@ -429,6 +441,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required
@login_required
@account_initialization_required
@@ -442,13 +455,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
@@ -473,8 +480,14 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
)
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
@api.expect(parser_run_api)
@setup_required
@login_required
@account_initialization_required
@@ -489,10 +502,7 @@ class RagPipelineDraftNodeRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_run_api.parse_args()
inputs = args.get("inputs")
if inputs == None:
@@ -607,8 +617,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
@api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@@ -622,8 +636,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
args = parser_default.parse_args()
q = args.get("q")
@@ -639,8 +652,18 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
@api.expect(parser_wf)
@setup_required
@login_required
@account_initialization_required
@@ -654,16 +677,9 @@ class PublishedAllRagPipelineApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
args = parser_wf.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
@@ -691,8 +707,16 @@ class PublishedAllRagPipelineApi(Resource):
}
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@api.expect(parser_wf_id)
@setup_required
@login_required
@account_initialization_required
@@ -707,19 +731,13 @@ class RagPipelineByIdApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_wf_id.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data
update_data = {}
@@ -752,8 +770,12 @@ class RagPipelineByIdApi(Resource):
return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -763,8 +785,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@@ -777,6 +798,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -786,8 +808,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@@ -800,6 +821,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -809,8 +831,7 @@ class DraftRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@@ -823,6 +844,7 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@@ -832,8 +854,7 @@ class DraftRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@@ -845,8 +866,16 @@ class DraftRagPipelineSecondStepApi(Resource):
}
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
@api.expect(parser_wf_run)
@setup_required
@login_required
@account_initialization_required
@@ -856,12 +885,7 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = parser_wf_run.parse_args()
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@@ -961,8 +985,18 @@ class RagPipelineTransformApi(Resource):
return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
@api.expect(parser_var)
@setup_required
@login_required
@account_initialization_required
@@ -974,14 +1008,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_var.parse_args()
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@@ -35,15 +35,18 @@ recommended_app_list_fields = {
}
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@api.expect(parser_apps)
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
args = parser.parse_args()
args = parser_apps.parse_args()
language = args.get("language")
if language and language in languages:

View File

@@ -8,6 +8,7 @@ import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
@@ -83,6 +84,8 @@ class FileApi(Resource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
return upload_file, 201

View File

@@ -10,6 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.console import api
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@@ -36,12 +37,15 @@ class RemoteFileInfoApi(Resource):
}
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@api.expect(parser_upload)
@marshal_with(file_fields_with_signed_url)
def post(self):
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
args = parser_upload.parse_args()
url = args["url"]

View File

@@ -49,6 +49,7 @@ class SetupApi(Resource):
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)

View File

@@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
@@ -16,6 +16,19 @@ def _validate_name(name):
return name
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tags")
class TagListApi(Resource):
@setup_required
@@ -30,6 +43,7 @@ class TagListApi(Resource):
return tags, 200
@api.expect(parser_tags)
@setup_required
@login_required
@account_initialization_required
@@ -39,20 +53,7 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@@ -60,8 +61,14 @@ class TagListApi(Resource):
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@api.expect(parser_tag_id)
@setup_required
@login_required
@account_initialization_required
@@ -72,10 +79,7 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
args = parser.parse_args()
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -99,8 +103,17 @@ class TagUpdateDeleteApi(Resource):
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@@ -110,26 +123,23 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_create.parse_args()
TagService.save_tag_binding(args)
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@api.expect(parser_remove)
@setup_required
@login_required
@account_initialization_required
@@ -139,15 +149,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
return {"result": "success"}, 200

View File

@@ -11,16 +11,16 @@ from . import api, console_ns
logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
@console_ns.route("/version")
class VersionApi(Resource):
@api.doc("check_version_update")
@api.doc(description="Check for application version updates")
@api.expect(
api.parser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
)
@api.expect(parser)
@api.response(
200,
"Success",
@@ -37,7 +37,6 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@@ -43,8 +43,19 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
def _init_parser():
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
return parser
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@api.expect(_init_parser())
@setup_required
@login_required
def post(self):
@@ -53,14 +64,7 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
args = parser.parse_args()
args = _init_parser().parse_args()
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
@@ -106,16 +110,19 @@ class AccountProfileApi(Resource):
return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@api.expect(parser_name)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_name.parse_args()
# Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
@@ -126,68 +133,80 @@ class AccountNameApi(Resource):
return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@api.expect(parser_avatar)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_avatar.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@api.expect(parser_interface)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
args = parser.parse_args()
args = parser_interface.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@api.expect(parser_theme)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
args = parser.parse_args()
args = parser_theme.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@api.expect(parser_timezone)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_timezone.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
@@ -198,21 +217,24 @@ class AccountTimezoneApi(Resource):
return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@api.expect(parser_pw)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_pw.parse_args()
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
@@ -294,20 +316,23 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_delete.parse_args()
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
raise InvalidAccountDeletionCodeError()
@@ -317,16 +342,19 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@api.expect(parser_feedback)
@setup_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_feedback.parse_args()
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
@@ -351,6 +379,14 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@@ -360,6 +396,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
@api.expect(parser_edu)
@setup_required
@login_required
@account_initialization_required
@@ -368,13 +405,7 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_edu.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@@ -394,6 +425,14 @@ class EducationApi(Resource):
return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@@ -402,6 +441,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
@api.expect(parser_autocomplete)
@setup_required
@login_required
@account_initialization_required
@@ -409,33 +449,30 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
args = parser.parse_args()
args = parser_autocomplete.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@api.expect(parser_change_email)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_change_email.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -470,20 +507,23 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@api.expect(parser_validity)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_validity.parse_args()
user_email = args["email"]
@@ -514,20 +554,23 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@api.expect(parser_reset)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_reset.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError()
@@ -555,12 +598,15 @@ class ChangeEmailResetApi(Resource):
return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@api.expect(parser_check)
@setup_required
def post(self):
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
args = parser.parse_args()
args = parser_check.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["email"]):

View File

@@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@@ -48,22 +48,25 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_invite = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@api.expect(parser_invite)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_invite.parse_args()
invitee_emails = args["emails"]
invitee_role = args["role"]
@@ -143,16 +146,19 @@ class MemberCancelInviteApi(Resource):
}, 200
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@api.expect(parser_update)
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_update.parse_args()
new_role = args["role"]
if not TenantAccountRole.is_valid_role(new_role):
@@ -191,17 +197,20 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@api.expect(parser_send)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = parser_send.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -229,19 +238,22 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
parser_owner = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@api.expect(parser_owner)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_owner.parse_args()
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@@ -276,17 +288,20 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_owner_transfer = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@api.expect(parser_owner_transfer)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
parser = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_owner_transfer.parse_args()
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()

View File

@@ -4,7 +4,7 @@ from flask import send_file
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -14,9 +14,19 @@ from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
parser_model = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@api.expect(parser_model)
@setup_required
@login_required
@account_initialization_required
@@ -24,15 +34,7 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
args = parser_model.parse_args()
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
@@ -40,8 +42,30 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list})
parser_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@api.expect(parser_cred)
@setup_required
@login_required
@account_initialization_required
@@ -49,10 +73,7 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
args = parser.parse_args()
args = parser_cred.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
@@ -61,6 +82,7 @@ class ModelProviderCredentialApi(Resource):
return {"credentials": credentials}
@api.expect(parser_post_cred)
@setup_required
@login_required
@account_initialization_required
@@ -69,12 +91,7 @@ class ModelProviderCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService()
@@ -90,6 +107,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 201
@api.expect(parser_put_cred)
@setup_required
@login_required
@account_initialization_required
@@ -98,13 +116,7 @@ class ModelProviderCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService()
@@ -121,6 +133,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}
@api.expect(parser_delete_cred)
@setup_required
@login_required
@account_initialization_required
@@ -128,10 +141,8 @@ class ModelProviderCredentialApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
@@ -141,8 +152,14 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204
parser_switch = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
@@ -150,10 +167,7 @@ class ModelProviderCredentialSwitchApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_switch.parse_args()
service = ModelProviderService()
service.switch_active_provider_credential(
@@ -164,17 +178,20 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"}
parser_validate = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@api.expect(parser_validate)
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_validate.parse_args()
tenant_id = current_tenant_id
@@ -218,8 +235,19 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
parser_preferred = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@api.expect(parser_preferred)
@setup_required
@login_required
@account_initialization_required
@@ -230,15 +258,7 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id
parser = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
args = parser.parse_args()
args = parser_preferred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(

View File

@@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -16,23 +16,29 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
parser_get_default = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
parser_post_default = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@api.expect(parser_get_default)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
args = parser_get_default.parse_args()
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
@@ -41,6 +47,7 @@ class DefaultModelApi(Resource):
return jsonable_encoder({"data": default_model_entity})
@api.expect(parser_post_default)
@setup_required
@login_required
@account_initialization_required
@@ -50,10 +57,7 @@ class DefaultModelApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_post_default.parse_args()
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@@ -84,6 +88,35 @@ class DefaultModelApi(Resource):
return {"result": "success"}
parser_post_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser_delete_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@@ -97,6 +130,7 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
@api.expect(parser_post_models)
@setup_required
@login_required
@account_initialization_required
@@ -106,23 +140,7 @@ class ModelProviderModelApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_post_models.parse_args()
if args.get("config_from", "") == "custom-model":
if not args.get("credential_id"):
@@ -160,6 +178,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@api.expect(parser_delete_models)
@setup_required
@login_required
@account_initialization_required
@@ -169,19 +188,7 @@ class ModelProviderModelApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_delete_models.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
@@ -191,29 +198,76 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204
parser_get_credentials = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_get_credentials)
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
args = parser_get_credentials.parse_args()
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
@@ -257,6 +311,7 @@ class ModelProviderModelCredentialApi(Resource):
}
)
@api.expect(parser_post_cred)
@setup_required
@login_required
@account_initialization_required
@@ -266,21 +321,7 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService()
@@ -304,6 +345,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 201
@api.expect(parser_put_cred)
@setup_required
@login_required
@account_initialization_required
@@ -313,22 +355,7 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService()
@@ -347,6 +374,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}
@api.expect(parser_delete_cred)
@setup_required
@login_required
@account_initialization_required
@@ -355,20 +383,7 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
@@ -382,8 +397,24 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204
parser_switch = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
@@ -392,20 +423,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_switch.parse_args()
service = ModelProviderService()
service.add_model_credential_to_model_list(
@@ -418,29 +436,32 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"}
parser_model_enable_disable = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@api.expect(parser_model_enable_disable)
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
@@ -454,25 +475,14 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@api.expect(parser_model_enable_disable)
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
@@ -482,28 +492,31 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"}
parser_validate = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@api.expect(parser_validate)
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_validate.parse_args()
model_provider_service = ModelProviderService()
@@ -530,16 +543,19 @@ class ModelProviderModelValidateApi(Resource):
return response
parser_parameter = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@api.expect(parser_parameter)
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_parameter.parse_args()
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()

View File

@@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -37,19 +37,22 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
parser_list = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@api.expect(parser_list)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
args = parser.parse_args()
args = parser_list.parse_args()
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
except PluginDaemonClientSideError as e:
@@ -58,14 +61,17 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@api.expect(parser_latest)
@setup_required
@login_required
@account_initialization_required
def post(self):
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args()
args = parser_latest.parse_args()
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
@@ -75,16 +81,19 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions})
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@api.expect(parser_ids)
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
args = parser_ids.parse_args()
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
@@ -94,16 +103,19 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins})
parser_icon = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@api.expect(parser_icon)
@setup_required
def get(self):
req = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
args = req.parse_args()
args = parser_icon.parse_args()
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
@@ -157,8 +169,17 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
parser_github = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@api.expect(parser_github)
@setup_required
@login_required
@account_initialization_required
@@ -166,13 +187,7 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_github.parse_args()
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
@@ -206,19 +221,21 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
parser_pkg = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@api.expect(parser_pkg)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
args = parser_pkg.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
@@ -233,8 +250,18 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response)
parser_githubapi = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@api.expect(parser_githubapi)
@setup_required
@login_required
@account_initialization_required
@@ -242,14 +269,7 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_githubapi.parse_args()
try:
response = PluginService.install_from_github(
@@ -265,8 +285,14 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
parser_marketplace = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@api.expect(parser_marketplace)
@setup_required
@login_required
@account_initialization_required
@@ -274,10 +300,7 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
args = parser_marketplace.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
@@ -292,19 +315,21 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response)
parser_pkgapi = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@api.expect(parser_pkgapi)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
args = parser_pkgapi.parse_args()
try:
return jsonable_encoder(
@@ -319,8 +344,14 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
parser_fetch = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@api.expect(parser_fetch)
@setup_required
@login_required
@account_initialization_required
@@ -328,10 +359,7 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
args = parser_fetch.parse_args()
try:
return jsonable_encoder(
@@ -345,8 +373,16 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e)
parser_tasks = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@api.expect(parser_tasks)
@setup_required
@login_required
@account_initialization_required
@@ -354,12 +390,7 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
args = parser_tasks.parse_args()
try:
return jsonable_encoder(
@@ -429,8 +460,16 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
parser_marketplace_api = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@api.expect(parser_marketplace_api)
@setup_required
@login_required
@account_initialization_required
@@ -438,12 +477,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_marketplace_api.parse_args()
try:
return jsonable_encoder(
@@ -455,8 +489,19 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e)
parser_github_post = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@api.expect(parser_github_post)
@setup_required
@login_required
@account_initialization_required
@@ -464,15 +509,7 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_github_post.parse_args()
try:
return jsonable_encoder(
@@ -489,15 +526,20 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e)
parser_uninstall = reqparse.RequestParser().add_argument(
"plugin_installation_id", type=str, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@api.expect(parser_uninstall)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
args = parser_uninstall.parse_args()
_, tenant_id = current_account_with_tenant()
@@ -507,8 +549,16 @@ class PluginUninstallApi(Resource):
raise ValueError(e)
parser_change_post = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@api.expect(parser_change_post)
@setup_required
@login_required
@account_initialization_required
@@ -518,12 +568,7 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
args = req.parse_args()
args = parser_change_post.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
@@ -558,8 +603,20 @@ class PluginFetchPermissionApi(Resource):
)
parser_dynamic = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@api.expect(parser_dynamic)
@setup_required
@login_required
@account_initialization_required
@@ -571,16 +628,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
user_id = current_user.id
parser = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
args = parser.parse_args()
args = parser_dynamic.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
@@ -599,8 +647,16 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
parser_change = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@api.expect(parser_change)
@setup_required
@login_required
@account_initialization_required
@@ -609,12 +665,7 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
args = req.parse_args()
args = parser_change.parse_args()
permission = args["permission"]
@@ -694,8 +745,12 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@api.expect(parser_exclude)
@setup_required
@login_required
@account_initialization_required
@@ -703,8 +758,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()
args = parser_exclude.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
@@ -52,8 +52,19 @@ def is_valid_url(url: str) -> bool:
return False
parser_tool = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
)
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
@api.expect(parser_tool)
@setup_required
@login_required
@account_initialization_required
@@ -62,15 +73,7 @@ class ToolProviderListApi(Resource):
user_id = user.id
req = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
)
args = req.parse_args()
args = parser_tool.parse_args()
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@@ -102,8 +105,14 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
parser_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@account_initialization_required
@@ -112,10 +121,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = req.parse_args()
args = parser_delete.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
tenant_id,
@@ -124,8 +130,17 @@ class ToolBuiltinProviderDeleteApi(Resource):
)
parser_add = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
@api.expect(parser_add)
@setup_required
@login_required
@account_initialization_required
@@ -134,13 +149,7 @@ class ToolBuiltinProviderAddApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_add.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
@@ -155,8 +164,17 @@ class ToolBuiltinProviderAddApi(Resource):
)
parser_update = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
@api.expect(parser_update)
@setup_required
@login_required
@account_initialization_required
@@ -168,14 +186,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_update.parse_args()
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
@@ -213,8 +224,22 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
parser_api_add = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
@api.expect(parser_api_add)
@setup_required
@login_required
@account_initialization_required
@@ -226,19 +251,7 @@ class ToolApiProviderAddApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api_add.parse_args()
return ApiToolManageService.create_api_tool_provider(
user_id,
@@ -254,8 +267,12 @@ class ToolApiProviderAddApi(Resource):
)
parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
@api.expect(parser_remote)
@setup_required
@login_required
@account_initialization_required
@@ -264,9 +281,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
args = parser_remote.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema(
user_id,
@@ -275,8 +290,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
)
parser_tools = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
@api.expect(parser_tools)
@setup_required
@login_required
@account_initialization_required
@@ -285,11 +306,7 @@ class ToolApiProviderListToolsApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_tools.parse_args()
return jsonable_encoder(
ApiToolManageService.list_api_tool_provider_tools(
@@ -300,8 +317,23 @@ class ToolApiProviderListToolsApi(Resource):
)
parser_api_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
@api.expect(parser_api_update)
@setup_required
@login_required
@account_initialization_required
@@ -313,20 +345,7 @@ class ToolApiProviderUpdateApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api_update.parse_args()
return ApiToolManageService.update_api_tool_provider(
user_id,
@@ -343,8 +362,14 @@ class ToolApiProviderUpdateApi(Resource):
)
parser_api_delete = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
@api.expect(parser_api_delete)
@setup_required
@login_required
@account_initialization_required
@@ -356,11 +381,7 @@ class ToolApiProviderDeleteApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_api_delete.parse_args()
return ApiToolManageService.delete_api_tool_provider(
user_id,
@@ -369,8 +390,12 @@ class ToolApiProviderDeleteApi(Resource):
)
parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args")
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
@api.expect(parser_get)
@setup_required
@login_required
@account_initialization_required
@@ -379,11 +404,7 @@ class ToolApiProviderGetApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_get.parse_args()
return ApiToolManageService.get_api_tool_provider(
user_id,
@@ -407,40 +428,44 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
)
parser_schema = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
@api.expect(parser_schema)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_schema.parse_args()
return ApiToolManageService.parser_api_schema(
schema=args["schema"],
)
parser_pre = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
@api.expect(parser_pre)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_pre.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_tenant_id,
@@ -453,8 +478,22 @@ class ToolApiProviderPreviousTestApi(Resource):
)
parser_create = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@@ -466,19 +505,7 @@ class ToolWorkflowProviderCreateApi(Resource):
user_id = user.id
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
args = parser_create.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id=user_id,
@@ -494,8 +521,22 @@ class ToolWorkflowProviderCreateApi(Resource):
)
parser_workflow_update = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
@api.expect(parser_workflow_update)
@setup_required
@login_required
@account_initialization_required
@@ -507,19 +548,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
user_id = user.id
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
args = parser_workflow_update.parse_args()
if not args["workflow_tool_id"]:
raise ValueError("incorrect workflow_tool_id")
@@ -538,8 +567,14 @@ class ToolWorkflowProviderUpdateApi(Resource):
)
parser_workflow_delete = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
@api.expect(parser_workflow_delete)
@setup_required
@login_required
@account_initialization_required
@@ -551,11 +586,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
user_id = user.id
reqparser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = reqparser.parse_args()
args = parser_workflow_delete.parse_args()
return WorkflowToolManageService.delete_workflow_tool(
user_id,
@@ -564,8 +595,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
)
parser_wf_get = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
@api.expect(parser_wf_get)
@setup_required
@login_required
@account_initialization_required
@@ -574,13 +613,7 @@ class ToolWorkflowProviderGetApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
args = parser_wf_get.parse_args()
if args.get("workflow_tool_id"):
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
@@ -600,8 +633,14 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool)
parser_wf_tools = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
@api.expect(parser_wf_tools)
@setup_required
@login_required
@account_initialization_required
@@ -610,11 +649,7 @@ class ToolWorkflowProviderListToolApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_wf_tools.parse_args()
return jsonable_encoder(
WorkflowToolManageService.list_single_workflow_tools(
@@ -790,32 +825,40 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_default_cred = reqparse.RequestParser().add_argument(
"id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
@api.expect(parser_default_cred)
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = parser_default_cred.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
parser_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
@api.expect(parser_custom)
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_custom.parse_args()
user, tenant_id = current_account_with_tenant()
@@ -878,25 +921,44 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
)
parser_mcp = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
parser_mcp_put = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
parser_mcp_delete = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
@api.expect(parser_mcp)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
args = parser_mcp.parse_args()
user, tenant_id = current_account_with_tenant()
# Parse and validate models
@@ -921,24 +983,12 @@ class ToolProviderMCPApi(Resource):
)
return jsonable_encoder(result)
@api.expect(parser_mcp_put)
@setup_required
@login_required
@account_initialization_required
def put(self):
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
args = parser_mcp_put.parse_args()
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
@@ -972,14 +1022,12 @@ class ToolProviderMCPApi(Resource):
)
return {"result": "success"}
@api.expect(parser_mcp_delete)
@setup_required
@login_required
@account_initialization_required
def delete(self):
parser = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_mcp_delete.parse_args()
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
@@ -988,18 +1036,21 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"}
parser_auth = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
@api.expect(parser_auth)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_auth.parse_args()
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
@@ -1097,15 +1148,18 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools)
parser_cb = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
@api.expect(parser_cb)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
args = parser.parse_args()
args = parser_cb.parse_args()
state_key = args["state"]
authorization_code = args["code"]

View File

@@ -13,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@@ -21,6 +21,7 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
setup_required,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
@@ -83,7 +84,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
@@ -149,15 +150,18 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_switch.parse_args()
# check if tenant_id is valid, 403 if not
try:
@@ -241,16 +245,19 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@api.expect(parser_info)
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_info.parse_args()
if not current_tenant_id:
raise ValueError("No current tenant")

View File

@@ -10,6 +10,7 @@ from flask import abort, request
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
@@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
abort(
403,
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",

View File

@@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@@ -2,6 +2,7 @@ from flask import request
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)

View File

@@ -13,6 +13,7 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -67,6 +68,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
kwargs["app_model"] = app_model
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
@@ -75,7 +77,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
@@ -90,6 +91,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)
@@ -139,7 +162,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
raise Forbidden(
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
)

View File

@@ -37,8 +37,7 @@ def trigger_endpoint(endpoint_id: str):
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e:
logger.exception("Endpoint processing failed for {endpoint_id}: {e}")
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 500
except Exception as e:
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 400
except Exception:
logger.exception("Webhook processing failed for {endpoint_id}")
return jsonify({"error": "Internal server error"}), 500

View File

@@ -88,12 +88,6 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
text_to_audio_response_fields = {
"audio_url": fields.String,
"duration": fields.Float,
}
@marshal_with(text_to_audio_response_fields)
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(

View File

@@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
query=query,
memory=memory,
)
@@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
query=query,
memory=memory,
)

View File

@@ -79,7 +79,7 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: str | None = None,
query: str = "",
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
@@ -105,7 +105,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query or "",
query=query,
files=files,
context=context,
memory=memory,

View File

@@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query or "",
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,

View File

@@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
# return batch, dataset, documents
return {
"batch": batch,

View File

@@ -39,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@overload
def generate(
self,
@@ -139,8 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
# start node get inputs
# for trigger debug run, not prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,

View File

@@ -44,6 +44,9 @@ class InvokeFrom(StrEnum):
DEBUGGER = "debugger"
PUBLISHED = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str):
"""
@@ -110,6 +113,11 @@ class AppGenerateEntity(BaseModel):
inputs: Mapping[str, Any]
files: Sequence[File]
# Unique identifier of the user initiating the execution.
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
#
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
user_id: str
# extras
@@ -135,7 +143,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: str | None = None
query: str = ""
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

View File

@@ -1,15 +1,64 @@
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
entity: WorkflowAppGenerateEntity
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]
class WorkflowResumptionContext(BaseModel):
"""WorkflowResumptionContext captures all state necessary for resumption."""
version: Literal["1"] = "1"
# Only workflow / chatflow could be paused.
generate_entity: _GenerateEntityUnion
serialized_graph_runtime_state: str
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, value: str) -> Self:
return cls.model_validate_json(value)
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
return self.generate_entity.entity
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(self, session_factory: Engine | sessionmaker[Session], state_owner_user_id: str):
def __init__(
self,
session_factory: Engine | sessionmaker[Session],
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating state file for pause.
@@ -19,6 +68,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@@ -49,13 +99,25 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
else:
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=self.graph_runtime_state.dumps(),
state=state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:

View File

@@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

View File

@@ -140,7 +140,27 @@ class MessageCycleManager:
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
# Add new unique resources from the event
for resource in event.retriever_resources or []:
if not resource:
continue
is_duplicate = (
resource.dataset_id
and resource.document_id
and (resource.dataset_id, resource.document_id) in existing_ids
)
if not is_duplicate:
merged_resources.append(resource)
for i, resource in enumerate(merged_resources, 1):
resource.position = i
self._task_state.metadata.retriever_resources = merged_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""

View File

@@ -0,0 +1,15 @@
from collections.abc import Sequence
from dataclasses import dataclass
@dataclass
class DocumentTask:
"""Document task entity for document indexing operations.
This class represents a document indexing task that can be queued
and processed by the document indexing system.
"""
tenant_id: str
dataset_id: str
document_ids: Sequence[str]

View File

@@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
# Return composite sort key: (model_type value, model position index)
return (model.model_type.value, position_index)
# Deduplicate
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
# Sort using the composite sort key
return sorted(provider_models, key=get_sort_key)

View File

@@ -74,6 +74,10 @@ class File(BaseModel):
storage_key: str | None = None,
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
url: str | None = None,
# Legacy compatibility fields - explicitly handle known extra fields
tool_file_id: str | None = None,
upload_file_id: str | None = None,
datasource_file_id: str | None = None,
):
super().__init__(
id=id,

View File

@@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(
f"""
// declare main function
{cls._code_placeholder}
runner_script = dedent(f""" {cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
"""
)
""")
return runner_script

View File

@@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
runner_script = dedent(f""" {cls._code_placeholder}
import json
from base64 import b64decode

View File

@@ -138,6 +138,10 @@ class StreamableHTTPTransport:
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
# ping event send by server will be recognized as a message event with empty data by httpx-sse's SSEDecoder
if not sse.data.strip():
return False
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug("SSE message: %s", message)

View File

@@ -52,7 +52,7 @@ class OpenAIModeration(Moderation):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
)
openai_moderation = model_instance.invoke_moderation(text=text)

View File

@@ -1,21 +1,22 @@
import hashlib
import json
import logging
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@@ -30,9 +31,10 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -99,22 +101,45 @@ def datetime_to_nanos(dt: datetime | None) -> int:
return int(dt.timestamp() * 1_000_000_000)
def string_to_trace_id128(string: str | None) -> int:
"""
Convert any input string into a stable 128-bit integer trace ID.
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
string_stacktrace = "".join(traceback.format_exception(error))
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
else:
error_message = str(error)
return error_message
This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
It's suitable for generating consistent, unique identifiers from strings.
"""
if string is None:
string = ""
hash_object = hashlib.sha256(string.encode())
# Take the first 16 bytes (128 bits) of the hash digest
digest = hash_object.digest()[:16]
def set_span_status(current_span: Span, error: Exception | str | None = None):
"""Set the status of the current span based on the presence of an error."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
# Convert to a 128-bit integer
return int.from_bytes(digest, byteorder="big")
if isinstance(error, Exception):
current_span.record_exception(error)
else:
exception_type = error.__class__.__name__
exception_message = str(error)
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:
current_span.set_status(Status(StatusCode.OK))
def safe_json_dumps(obj: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
return json.dumps(obj, default=str, ensure_ascii=False)
class ArizePhoenixDataTrace(BaseTraceInstance):
@@ -131,9 +156,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.propagator = TraceContextTextMapPropagator()
self.dify_trace_ids: set[str] = set()
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@@ -151,7 +179,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
@@ -166,15 +194,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
workflow_metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
@@ -186,31 +208,58 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
# Through workflow_run_id, get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
try:
# Process workflow nodes
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
for node_execution in workflow_node_executions:
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
inputs_value = node_execution.inputs or {}
outputs_value = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
process_data = node_execution.process_data or {}
execution_metadata = node_execution.metadata or {}
node_metadata = {str(k): v for k, v in execution_metadata.items()}
node_metadata = {
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": node_execution.tenant_id,
"app_id": node_execution.app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
}
if node_execution.execution_metadata:
node_metadata.update(json.loads(node_execution.execution_metadata))
node_metadata.update(
{
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN
@@ -223,8 +272,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model:
node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
@@ -236,17 +286,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
else:
span_kind = OpenInferenceSpanKindValues.CHAIN
workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=workflow_span_context,
)
try:
@@ -260,11 +313,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
outputs = (
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
)
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
@@ -275,8 +325,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
if node_execution.status == "failed":
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
@@ -322,34 +380,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
message_span_id = RandomIdGenerator().generate_span_id()
span_context = SpanContext(
trace_id=trace_id,
span_id=message_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
context=root_span_context,
)
try:
if trace_info.error:
message_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
@@ -383,26 +425,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
message_span_context = set_span_in_context(message_span)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
context=message_span_context,
)
try:
if trace_info.error:
llm_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
if trace_info.message_data.error:
set_span_status(llm_span, trace_info.message_data.error)
else:
set_span_status(llm_span)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
if trace_info.error:
set_span_status(message_span, trace_info.error)
else:
set_span_status(message_span)
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
@@ -418,15 +460,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
@@ -445,19 +481,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
@@ -480,15 +511,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
@@ -499,19 +524,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(end_time))
@@ -533,15 +553,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
@@ -554,19 +568,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(end_time))
@@ -580,20 +589,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
trace_id = string_to_trace_id128(trace_info.message_id)
tool_span_id = RandomIdGenerator().generate_span_id()
logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
# Create span context with the same trace_id as the parent
# todo: Create with the appropriate parent span context, so that the tool span is
# a child of the appropriate span (e.g. message span)
span_context = SpanContext(
trace_id=trace_id,
span_id=tool_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
@@ -612,19 +610,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
context=root_span_context,
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
@@ -641,15 +634,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
@@ -663,22 +650,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def ensure_root_span(self, dify_trace_id: str | None):
"""Ensure a unique root span exists for the given Dify trace ID."""
if str(dify_trace_id) not in self.dify_trace_ids:
self.carrier: dict[str, str] = {}
root_span = self.tracer.start_span(name="Dify")
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
root_span.set_attribute("dify_project_name", str(self.project))
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
with use_span(root_span, end_on_exit=False):
self.propagator.inject(carrier=self.carrier)
set_span_status(root_span)
root_span.end()
self.dify_trace_ids.add(str(dify_trace_id))
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
@@ -698,26 +697,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = db.session.scalars(
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
).all()
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}

View File

@@ -5,6 +5,7 @@ Tencent APM Trace Client - handles network operations, metrics, and API communic
from __future__ import annotations
import importlib
import json
import logging
import os
import socket
@@ -110,6 +111,7 @@ class TencentTraceClient:
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.meter_provider: MeterProvider | None = None
self.hist_llm_duration: Histogram | None = None
self.hist_token_usage: Histogram | None = None
self.hist_time_to_first_token: Histogram | None = None
@@ -119,7 +121,6 @@ class TencentTraceClient:
# Metrics exporter and instruments
try:
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
@@ -202,9 +203,11 @@ class TencentTraceClient:
)
if metric_reader is not None:
# Use instance-level MeterProvider instead of global to support config changes
# without worker restart. Each TencentTraceClient manages its own MeterProvider.
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
metrics.set_meter_provider(provider)
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
self.meter_provider = provider
self.meter = provider.get_meter("dify-sdk", dify_config.project.version)
# LLM operation duration histogram
self.hist_llm_duration = self.meter.create_histogram(
@@ -244,6 +247,7 @@ class TencentTraceClient:
self.metric_reader = metric_reader
else:
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
@@ -253,6 +257,7 @@ class TencentTraceClient:
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
@@ -279,6 +284,14 @@ class TencentTraceClient:
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
LLM_OPERATION_DURATION,
latency_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
@@ -317,6 +330,13 @@ class TencentTraceClient:
"server.address": server_address,
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
GEN_AI_TOKEN_USAGE,
token_count,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_token_usage.record(token_count, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)
@@ -344,6 +364,13 @@ class TencentTraceClient:
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
ttft_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_first_token.record(ttft_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)
@@ -371,6 +398,13 @@ class TencentTraceClient:
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_STREAMING_TIME_TO_GENERATE,
ttg_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_generate.record(ttg_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)
@@ -390,6 +424,14 @@ class TencentTraceClient:
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_TRACE_DURATION,
duration_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_trace_duration.record(duration_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)
@@ -474,11 +516,19 @@ class TencentTraceClient:
if self.tracer_provider:
self.tracer_provider.shutdown()
# Shutdown instance-level meter provider
if self.meter_provider is not None:
try:
self.meter_provider.shutdown() # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
pass
logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")

View File

@@ -246,7 +246,7 @@ class RequestFetchAppInfo(BaseModel):
class TriggerInvokeEventResponse(BaseModel):
variables: Mapping[str, Any] = Field(default_factory=dict)
cancelled: bool | None = False
cancelled: bool = Field(default=False)
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)

View File

@@ -147,7 +147,8 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return cast(str, info["version"]["number"])
# remove any suffix like "-SNAPSHOT" from the version string
return cast(str, info["version"]["number"]).split("-")[0]
def _check_version(self):
if parse_version(self._version) < parse_version("8.0.0"):

View File

@@ -39,11 +39,13 @@ class WeaviateConfig(BaseModel):
Attributes:
endpoint: Weaviate server endpoint URL
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
grpc_endpoint: str | None = None
api_key: str | None = None
batch_size: int = 100
@@ -88,9 +90,22 @@ class WeaviateVector(BaseVector):
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
@@ -432,6 +447,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),

View File

@@ -152,13 +152,15 @@ class WordExtractor(BaseExtractor):
# Initialize a row, all of which are empty by default
row_cells = [""] * total_cols
col_index = 0
for cell in row.cells:
while col_index < len(row.cells):
# make sure the col_index is not out of range
while col_index < total_cols and row_cells[col_index] != "":
while col_index < len(row.cells) and row_cells[col_index] != "":
col_index += 1
# if col_index is out of range the loop is jumped
if col_index >= total_cols:
if col_index >= len(row.cells):
break
# get the correct cell
cell = row.cells[col_index]
cell_content = self._parse_cell(cell, image_map).strip()
cell_colspan = cell.grid_span or 1
for i in range(cell_colspan):

View File

View File

@@ -0,0 +1,82 @@
import json
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel, ValidationError
from extensions.ext_redis import redis_client
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
class TaskWrapper(BaseModel):
data: Any
def serialize(self) -> str:
return self.model_dump_json()
@classmethod
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
return cls.model_validate_json(serialized_data)
class TenantIsolatedTaskQueue:
"""
Simple queue for tenant isolated tasks, used for rag related tenant tasks isolation.
It uses Redis list to store tasks, and Redis key to store task waiting flag.
Support tasks that can be serialized by json.
"""
def __init__(self, tenant_id: str, unique_key: str):
self._tenant_id = tenant_id
self._unique_key = unique_key
self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
self._task_key = f"tenant_{unique_key}_task:{tenant_id}"
def get_task_key(self):
return redis_client.get(self._task_key)
def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
redis_client.setex(self._task_key, ttl, 1)
def delete_task_key(self):
redis_client.delete(self._task_key)
def push_tasks(self, tasks: Sequence[Any]):
serialized_tasks = []
for task in tasks:
# Store str list directly, maintaining full compatibility for pipeline scenarios
if isinstance(task, str):
serialized_tasks.append(task)
else:
# Use TaskWrapper to do JSON serialization for non-string tasks
wrapper = TaskWrapper(data=task)
serialized_data = wrapper.serialize()
serialized_tasks.append(serialized_data)
if not serialized_tasks:
return
redis_client.lpush(self._queue, *serialized_tasks)
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
if count <= 0:
return []
tasks = []
for _ in range(count):
serialized_task = redis_client.rpop(self._queue)
if not serialized_task:
break
if isinstance(serialized_task, bytes):
serialized_task = serialized_task.decode("utf-8")
try:
wrapper = TaskWrapper.deserialize(serialized_task)
tasks.append(wrapper.data)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
# Fall back to raw string for legacy format or invalid JSON
tasks.append(serialized_task)
return tasks

View File

@@ -7,8 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy import and_, or_, select
from core.app.app_config.entities import (
DatasetEntity,
@@ -1023,60 +1022,55 @@ class DatasetRetrieval:
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return
return filters
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition:
case "contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"{value}%"}
)
)
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}"}
)
)
filters.append(json_field.like(f"%{value}"))
case "is" | "=":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
def _fetch_model_config(

View File

@@ -210,12 +210,13 @@ class Tool(ABC):
meta=meta,
)
def create_json_message(self, object: dict) -> ToolInvokeMessage:
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
type=ToolInvokeMessage.MessageType.JSON,
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
)
def create_variable_message(

View File

@@ -129,6 +129,7 @@ class ToolInvokeMessage(BaseModel):
class JsonMessage(BaseModel):
json_object: dict
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
blob: bytes

View File

@@ -1,16 +1,19 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
logger = logging.getLogger(__name__)
class MCPTool(Tool):
def __init__(
@@ -52,6 +55,11 @@ class MCPTool(Tool):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
else:
logger.warning("Unsupported content type=%s", type(content))
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
@@ -97,6 +105,10 @@ class MCPTool(Tool):
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

View File

@@ -228,29 +228,41 @@ class ToolEngine:
"""
Handle tool response
"""
result = ""
parts: list[str] = []
json_parts: list[str] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += cast(ToolInvokeMessage.TextMessage, response.message).text
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += (
parts.append(
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
parts.append(
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
if json_message.suppress_output:
continue
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
)
else:
result += str(response.message)
parts.append(str(response.message))
return result
# Add JSON parts, avoiding duplicates from text parts.
if json_parts:
existing_parts = set(parts)
parts.extend(p for p in json_parts if p not in existing_parts)
return "".join(parts)
@staticmethod
def _extract_tool_response_binary_and_text(

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -32,7 +33,6 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
@@ -618,12 +618,28 @@ class ToolManager:
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use DISTINCT ON
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
else:
# MySQL: Use window function to achieve same result
sql = """
SELECT id FROM (
SELECT id,
ROW_NUMBER() OVER (
PARTITION BY tenant_id, provider
ORDER BY is_default DESC, created_at DESC
) as rn
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
) ranked WHERE rn = 1
"""
with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()

View File

@@ -117,7 +117,7 @@ class WorkflowTool(Tool):
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
yield self.create_json_message(outputs, suppress_output=True)
@property
def latest_usage(self) -> LLMUsage:

View File

@@ -208,7 +208,7 @@ class SubscriptionBuilder(BaseModel):
endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
credentials: Mapping[str, str] = Field(..., description="The credentials of the subscription builder")
credentials: Mapping[str, Any] = Field(..., description="The credentials of the subscription builder")
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
credential_expires_at: int | None = Field(
default=None, description="The credential expires at of the subscription builder"
@@ -227,7 +227,7 @@ class SubscriptionBuilderUpdater(BaseModel):
name: str | None = Field(default=None, description="The name of the subscription builder")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
credentials: Mapping[str, str] | None = Field(
credentials: Mapping[str, Any] | None = Field(
default=None, description="The credentials of the subscription builder"
)
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")

View File

@@ -13,14 +13,14 @@ import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError, PluginNotFoundError
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
from core.plugin.impl.trigger import PluginTriggerClient
from core.trigger.entities.entities import (
EventEntity,
Subscription,
UnsubscribeResult,
)
from core.trigger.errors import EventIgnoreError, TriggerPluginInvokeError
from core.trigger.errors import EventIgnoreError
from core.trigger.provider import PluginTriggerProviderController
from models.provider_ids import TriggerProviderID
@@ -189,13 +189,10 @@ class TriggerManager:
request=request,
payload=payload,
)
except EventIgnoreError as e:
except EventIgnoreError:
return TriggerInvokeEventResponse(variables={}, cancelled=True)
except PluginInvokeError as e:
logger.exception("Failed to invoke trigger event")
raise TriggerPluginInvokeError(
description=e.to_user_friendly_error(plugin_name=provider.entity.identity.name)
) from e
except Exception as e:
raise e
@classmethod
def subscribe_trigger(

View File

@@ -202,6 +202,35 @@ class SegmentType(StrEnum):
raise ValueError(f"element_type is only supported by array type, got {self}")
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: "SegmentType"):
# Lazy import to avoid circular dependency
from factories import variable_factory
match t:
case (
SegmentType.ARRAY_OBJECT
| SegmentType.ARRAY_ANY
| SegmentType.ARRAY_STRING
| SegmentType.ARRAY_NUMBER
| SegmentType.ARRAY_BOOLEAN
):
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return variable_factory.build_segment(False)
case _:
raise ValueError(f"unsupported variable type: {t}")
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have corresponding element type.

View File

@@ -114,9 +114,45 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)

View File

@@ -16,7 +16,6 @@ from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
@@ -108,8 +107,8 @@ class Worker(threading.Thread):
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
node_id="unknown",
node_type=NodeType.CODE,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),

View File

@@ -6,12 +6,12 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetDocument
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
@@ -597,79 +597,79 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
if value is None and condition not in ("empty", "not empty"):
return filters
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"{value}%"}
)
)
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}"}
)
)
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
escaped_value_str = ",".join(escaped_values)
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
escaped_value_str = str(value)
filters.append(
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
**{key: metadata_name, key_value: escaped_value_str}
)
)
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
escaped_value_str = ",".join(escaped_values)
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
escaped_value_str = str(value)
filters.append(
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
**{key: metadata_name, key_value: escaped_value_str}
)
)
case "=" | "is":
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod

View File

@@ -2,7 +2,6 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
@@ -12,7 +11,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
@@ -116,7 +114,7 @@ class VariableAssignerNode(Node):
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
# Over write the variable.
@@ -143,24 +141,3 @@ class VariableAssignerNode(Node):
process_data=common_helpers.set_updated_variables({}, updated_variables),
outputs={},
)
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return BooleanSegment(value=False)
case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@@ -1,14 +0,0 @@
from core.variables import SegmentType
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.BOOLEAN: False,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_BOOLEAN: [],
}

View File

@@ -16,7 +16,6 @@ from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNod
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
@@ -249,7 +248,7 @@ class VariableAssignerNode(Node):
case Operation.OVER_WRITE:
return value
case Operation.CLEAR:
return EMPTY_VALUE_MAPPING[variable.value_type]
return SegmentType.get_zero_value(variable.value_type).to_object()
case Operation.APPEND:
return variable.value + [value]
case Operation.EXTEND:

View File

@@ -153,7 +153,11 @@ class VariablePool(BaseModel):
return None
node_id, name = self._selector_to_keys(selector)
segment: Segment | None = self.variable_dictionary[node_id].get(name)
node_map = self.variable_dictionary.get(node_id)
if node_map is None:
return None
segment: Segment | None = node_map.get(name)
if segment is None:
return None

View File

@@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

0
api/enums/__init__.py Normal file
View File

15
api/enums/cloud_plan.py Normal file
View File

@@ -0,0 +1,15 @@
from enum import StrEnum, auto
class CloudPlan(StrEnum):
"""
Enum representing user plan types in the cloud platform.
SANDBOX: Free/default plan with limited features
PROFESSIONAL: Professional paid plan
TEAM: Team collaboration paid plan
"""
SANDBOX = auto()
PROFESSIONAL = auto()
TEAM = auto()

View File

@@ -116,6 +116,7 @@ app_partial_fields = {
"access_mode": fields.String,
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
}

View File

@@ -0,0 +1,134 @@
"""
Broadcast channel for Pub/Sub messaging.
"""
import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self
class Subscription(AbstractContextManager["Subscription"], Protocol):
"""A subscription to a topic that provides an iterator over received messages.
The subscription can be used as a context manager and will automatically
close when exiting the context.
Note: `Subscription` instances are not thread-safe. Each thread should create its own
subscription.
"""
@abstractmethod
def __iter__(self) -> Iterator[bytes]:
"""`__iter__` returns an iterator used to consume the message from this subscription.
If the caller did not enter the context, `__iter__` may lazily perform the setup before
yielding messages; otherwise `__enter__` handles it.”
If the subscription is closed, then the returned iterator exits without
raising any error.
"""
...
@abstractmethod
def close(self) -> None:
"""close closes the subscription, releases any resources associated with it."""
...
def __enter__(self) -> Self:
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
@abstractmethod
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the broadcast channel.
If `timeout` is specified, this method returns `None` if no message is
received within the given period. If `timeout` is `None`, the call blocks
until a message is received.
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
cancel a blocking subscription.
:param timeout: timeout for receive message, in seconds.
Returns:
bytes: The received message as a byte string, or
None: If the timeout expires before a message is received.
Raises:
SubscriptionClosed: If the subscription has already been closed.
"""
...
class Producer(Protocol):
"""Producer is an interface for message publishing. It is already bound to a specific topic.
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def publish(self, payload: bytes) -> None:
"""Publish a message to the bounded topic."""
...
class Subscriber(Protocol):
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def subscribe(self) -> Subscription:
pass
class Topic(Producer, Subscriber, Protocol):
"""A named channel for publishing and subscribing to messages.
Topics provide both read and write access. For restricted access,
use as_producer() for write-only view or as_subscriber() for read-only view.
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def as_producer(self) -> Producer:
"""as_producer creates a write-only view for this topic."""
...
@abstractmethod
def as_subscriber(self) -> Subscriber:
"""as_subscriber create a read-only view for this topic."""
...
class BroadcastChannel(Protocol):
"""A broadcasting channel is a channel supporting broadcasting semantics.
Each channel is identified by a topic, different topics are isolated and do not affect each other.
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
a specific topic, all subscription should receive the published message.
There are no restriction for the persistence of messages. Once a subscription is created, it
should receive all subsequent messages published.
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def topic(self, topic: str) -> "Topic":
"""topic returns a `Topic` instance for the given topic name."""
...

View File

@@ -0,0 +1,12 @@
class BroadcastChannelError(Exception):
"""`BroadcastChannelError` is the base class for all exceptions related
to `BroadcastChannel`."""
pass
class SubscriptionClosedError(BroadcastChannelError):
"""SubscriptionClosedError means that the subscription has been closed and
methods for consuming messages should not be called."""
pass

View File

@@ -0,0 +1,3 @@
from .channel import BroadcastChannel
__all__ = ["BroadcastChannel"]

View File

@@ -0,0 +1,200 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class BroadcastChannel:
"""
Redis Pub/Sub based broadcast channel implementation.
Provides "at most once" delivery semantics for messages published to channels.
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
"""
def __init__(
self,
redis_client: Redis,
):
self._client = redis_client
def topic(self, topic: str) -> "Topic":
return Topic(self._client, topic)
class Topic:
def __init__(self, redis_client: Redis, topic: str):
self._client = redis_client
self._topic = topic
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _RedisSubscription(
pubsub=self._client.pubsub(),
topic=self._topic,
)
class _RedisSubscription(Subscription):
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
def _start_if_needed(self) -> None:
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
self._pubsub.subscribe(self._topic)
_logger.debug("Subscribed to channel %s", self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _listen(self) -> None:
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if raw_message is None:
continue
if raw_message.get("type") != "message":
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
continue
self._enqueue_message(payload_bytes)
_logger.debug("Listener thread stopped for channel %s", self._topic)
pubsub.unsubscribe(self._topic)
pubsub.close()
_logger.debug("PubSub closed for topic %s", self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
# method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None

View File

@@ -2,6 +2,8 @@ import abc
import datetime
from typing import Protocol
import pytz
class _NowFunction(Protocol):
@abc.abstractmethod
@@ -31,3 +33,51 @@ def ensure_naive_utc(dt: datetime.datetime) -> datetime.datetime:
if dt.tzinfo is None:
return dt
return dt.astimezone(datetime.UTC).replace(tzinfo=None)
def parse_time_range(
start: str | None, end: str | None, tzname: str
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
"""
Parse time range strings and convert to UTC datetime objects.
Handles DST ambiguity and non-existent times gracefully.
Args:
start: Start time string (YYYY-MM-DD HH:MM)
end: End time string (YYYY-MM-DD HH:MM)
tzname: Timezone name
Returns:
tuple: (start_datetime_utc, end_datetime_utc)
Raises:
ValueError: When time range is invalid or start > end
"""
tz = pytz.timezone(tzname)
utc = pytz.utc
def _parse(time_str: str | None, label: str) -> datetime.datetime | None:
if not time_str:
return None
try:
dt = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M").replace(second=0)
except ValueError as e:
raise ValueError(f"Invalid {label} time format: {e}")
try:
return tz.localize(dt, is_dst=None).astimezone(utc)
except pytz.AmbiguousTimeError:
return tz.localize(dt, is_dst=False).astimezone(utc)
except pytz.NonExistentTimeError:
dt += datetime.timedelta(hours=1)
return tz.localize(dt, is_dst=None).astimezone(utc)
start_dt = _parse(start, "start")
end_dt = _parse(end, "end")
# Range validation
if start_dt and end_dt and start_dt > end_dt:
raise ValueError("start must be earlier than or equal to end")
return start_dt, end_dt

View File

@@ -177,6 +177,15 @@ def timezone(timezone_string):
raise ValueError(error)
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
if dify_config.DB_TYPE == "postgresql":
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
elif dify_config.DB_TYPE == "mysql":
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
else:
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")
def generate_string(n):
letters_digits = string.ascii_letters + string.digits
result = ""

Some files were not shown because too many files have changed in this diff Show More