Compare commits

..

359 Commits

Author SHA1 Message Date
hjlarry
486a30402b remove forceUpload 2026-01-23 14:33:15 +08:00
hjlarry
e105dc6289 new restore 2026-01-23 14:22:58 +08:00
hjlarry
51c8c50b82 expire leader key in redis 2026-01-22 09:30:51 +08:00
hjlarry
1b70a7e4c7 use contract for api request 2026-01-21 18:20:38 +08:00
hjlarry
eaf888b02a env var NEXT_PUBLIC_SOCKET_URL 2026-01-20 20:34:56 +08:00
hjlarry
f99ac24d5c websocket use cookie connect 2026-01-20 17:01:40 +08:00
hjlarry
bdac6f91dd add socket edit permission validate 2026-01-20 13:56:28 +08:00
hjlarry
9be496f953 fix publish workflow not sync 2026-01-20 13:20:02 +08:00
hjlarry
4acca22ff0 whether resolved sync to canvas 2026-01-20 10:12:15 +08:00
hjlarry
018175ec2d Merge branch 'feat/collaboration2' of github.com:langgenius/dify into feat/collaboration2 2026-01-19 21:54:01 +08:00
hjlarry
faa88dc2f3 fix unittests 2026-01-19 21:53:56 +08:00
hjlarry
060c7f2b45 fix pyright 2026-01-19 21:48:05 +08:00
hjlarry
acb603bff7 fix migration file 2026-01-19 21:46:40 +08:00
hjlarry
e36ee54a16 fix web style 2026-01-19 21:44:26 +08:00
autofix-ci[bot]
f3fa4f11ba [autofix.ci] apply automated fixes 2026-01-19 13:18:15 +00:00
hjlarry
cb8fc9cf2d Merge remote-tracking branch 'myori/main' into feat/collaboration2 2026-01-19 21:15:53 +08:00
hjlarry
aaa3d2d74f add unittests 2026-01-19 21:11:44 +08:00
hjlarry
c17f564718 add unittests 2026-01-19 20:41:21 +08:00
hjlarry
3389071361 add unittests 2026-01-19 20:25:47 +08:00
hjlarry
41473ff450 refactor workflow collaboration service 2026-01-19 19:56:18 +08:00
hjlarry
805bb7c468 fix node in panel sync 2026-01-19 18:01:43 +08:00
盐粒 Yanli
62ac02a568 feat: Download the uploaded files (#31068)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-19 16:48:13 +08:00
zxhlyh
2d4289a925 chore: relocate datasets api form (#31224) 2026-01-19 16:15:51 +08:00
wangxiaolei
88780c7eb7 fix: Revert "fix: fix create app xss issue" (#31219) 2026-01-19 16:07:24 +08:00
wangxiaolei
0f1db88dcb fix: fix dify-plugin-daemon error message (#31218) 2026-01-19 16:00:44 +08:00
hjlarry
995d5ccf66 fix graph not sync 2026-01-19 13:45:00 +08:00
hjlarry
0d08f7db97 fix 2026-01-18 18:36:44 +08:00
autofix-ci[bot]
6443366f50 [autofix.ci] apply automated fixes 2026-01-18 10:01:22 +00:00
非法操作
70c41a7dc3 Update api/controllers/console/app/workflow.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-18 17:59:18 +08:00
非法操作
8804623121 Update api/app_factory.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-18 17:59:10 +08:00
hjlarry
1fb6d1286f fix webtest 2026-01-18 17:27:29 +08:00
hjlarry
511df81201 fix web style 2026-01-18 13:40:12 +08:00
hjlarry
682c93f262 Merge remote-tracking branch 'myori/main' into feat/collaboration2 2026-01-18 10:28:50 +08:00
hjlarry
51c96b0b7e fix CI 2026-01-18 10:12:43 +08:00
hjlarry
224f426765 fix CI 2026-01-18 10:07:46 +08:00
autofix-ci[bot]
e9657cfb48 [autofix.ci] apply automated fixes 2026-01-17 15:00:37 +00:00
hjlarry
4200ac0da3 fix CI 2026-01-17 22:58:27 +08:00
hjlarry
434f7f3bcb fix web style 2026-01-17 22:10:10 +08:00
hjlarry
03cc196965 fix CI 2026-01-17 22:05:14 +08:00
hjlarry
25c88b3f5c fix mypy 2026-01-17 21:41:03 +08:00
hjlarry
2d94904241 fix web unittests 2026-01-17 19:43:40 +08:00
hjlarry
a99e70d96e fix CI 2026-01-17 15:55:27 +08:00
hjlarry
9eeceb2455 fix basedpyright 2026-01-17 15:54:32 +08:00
autofix-ci[bot]
7901e18fa6 [autofix.ci] apply automated fixes 2026-01-17 06:57:16 +00:00
hjlarry
2befef0b21 Merge branch 'feat/collaboration2' of github.com:langgenius/dify into feat/collaboration2 2026-01-17 14:55:22 +08:00
hjlarry
8869cd7008 fix api 2026-01-17 14:55:12 +08:00
hjlarry
91e6ae2a7d fix bug 2026-01-17 14:53:33 +08:00
hjlarry
6ab8e05a5e fix api 2026-01-17 14:47:44 +08:00
hjlarry
717f99a352 fix migration file 2026-01-17 12:54:15 +08:00
hjlarry
735cd78dc2 fix api 2026-01-17 12:45:40 +08:00
autofix-ci[bot]
c820501cbb [autofix.ci] apply automated fixes (attempt 2/3) 2026-01-17 04:29:38 +00:00
autofix-ci[bot]
43ef2395ac [autofix.ci] apply automated fixes 2026-01-17 04:27:34 +00:00
hjlarry
bb3d94f1c5 Merge remote-tracking branch 'myori/main' into feat/collaboration2 2026-01-17 12:24:37 +08:00
hjlarry
c45fbb6491 rm workflow.ts 2026-01-17 10:26:12 +08:00
hjlarry
fc291e4ca2 Merge remote-tracking branch 'myori/main' into feat/collaboration2 2026-01-17 10:22:41 +08:00
hjlarry
b549d669d6 clear logic 2026-01-15 13:17:14 +08:00
hjlarry
802b38eede fix 2026-01-15 13:16:35 +08:00
hjlarry
4b57e7bd53 fix 2026-01-15 11:42:34 +08:00
hjlarry
bfedee0532 fix 2026-01-14 16:40:52 +08:00
hjlarry
1845938e70 fix type issue 2026-01-13 22:18:54 +08:00
hjlarry
fad81ab85e fix type issue 2026-01-13 22:11:36 +08:00
hjlarry
d1c64f5c74 add toast when disconnected 2026-01-13 22:08:59 +08:00
hjlarry
7f6c93bdce reduce CURSOR_THROTTLE_MS 2026-01-13 22:08:07 +08:00
hjlarry
7730c88c74 fix leader election concurrently 2026-01-13 18:01:12 +08:00
hjlarry
ac6b540fd8 CORS config 2026-01-13 17:50:16 +08:00
hjlarry
8c9276370c remove console.log 2026-01-13 17:46:53 +08:00
hjlarry
b91370aff7 fix next config 2026-01-13 17:40:04 +08:00
hjlarry
30424df7ce uuid v7 2026-01-13 17:20:02 +08:00
hjlarry
14f7f4758a fix error display 2026-01-13 17:19:52 +08:00
hjlarry
79c19983e0 refactor: fix N+1 query issue in workflow comments 2026-01-13 16:56:54 +08:00
hjlarry
aeb3fc6729 add backend logging 2026-01-13 16:25:54 +08:00
hjlarry
0c18d4e058 fix duplicated status 2026-01-13 15:59:59 +08:00
hjlarry
bd597497e7 prevent comment thread pinch 2025-11-27 15:37:46 +08:00
hjlarry
be1f841b37 control panel should be z-60 2025-11-24 16:27:37 +08:00
hjlarry
d98a428100 Revert "fix model config panel z-index"
This reverts commit f85bf0867c.
2025-11-24 16:23:10 +08:00
hjlarry
26d330e744 setting dialog should be z-index 60 2025-11-24 16:19:29 +08:00
hjlarry
61bed38afb Reapply "fix system model setting modal index"
This reverts commit 16fbc6b270.
2025-11-24 16:16:56 +08:00
hjlarry
16fbc6b270 Revert "fix system model setting modal index"
This reverts commit fe132de3c8.
2025-11-24 16:16:45 +08:00
hjlarry
fe132de3c8 fix system model setting modal index 2025-11-24 16:12:18 +08:00
hjlarry
f85bf0867c fix model config panel z-index 2025-11-24 16:10:46 +08:00
hjlarry
b441a7fbc4 fix style 2025-11-18 10:31:56 +08:00
hjlarry
8497d296b1 feat: can drag avatar to move the comment input 2025-11-18 09:53:15 +08:00
hjlarry
3ee2508ec8 fix comment input also not allow to zoomin canvas 2025-11-17 16:17:34 +08:00
hjlarry
ff8d5ac4b5 fix gesture zoom in 2025-11-17 15:37:43 +08:00
hjlarry
7fc98b2183 fix sync of webhook node 2025-11-14 11:31:08 +08:00
hjlarry
a4adafd8ad remove the single env button 2025-11-14 11:00:33 +08:00
hjlarry
c1bc3aeab9 fix migration file 2025-11-14 10:58:16 +08:00
hjlarry
edf962cdb5 Merge branch 'feat/collaboration' into feat/collaboration2 2025-11-13 15:31:21 +08:00
hjlarry
2fa13cdf86 if session unauthorized, rejoin 2025-11-11 16:38:55 +08:00
hjlarry
39de7673eb add redis key expire time for collaboration 2025-11-11 16:13:05 +08:00
hjlarry
d930d8cc4a fix setting dialog z-index 2025-11-10 18:02:36 +08:00
hjlarry
97626a3ba5 can't zoomOnPinch when mouse over comment preview 2025-11-07 09:27:49 +08:00
hjlarry
b7f7d04639 fix comment input mention not display avatar 2025-11-05 18:09:42 +08:00
hjlarry
13674bd859 comment input mode click empty place can close 2025-11-05 17:41:10 +08:00
hjlarry
fb9cbc0471 comment mode can't click node 2025-11-05 14:14:36 +08:00
hjlarry
2f60288d86 fix: resize workflow canvas cause incorrect comment position 2025-11-05 14:08:21 +08:00
hjlarry
ee3ded0fc2 fix control layer 2025-10-22 10:25:31 +08:00
hjlarry
351bad9ec4 fix minimap disable collobroation 2025-10-22 10:21:25 +08:00
hjlarry
9bf7473bbf hide comments when disable collaboration 2025-10-22 10:10:23 +08:00
hjlarry
fa09c88f5c add CollaborationEnabled for comment shortcut 2025-10-22 09:59:43 +08:00
hjlarry
83df78d0c8 hide comments icon when disable collabrotion mode 2025-10-22 09:50:37 +08:00
hjlarry
79266f7302 add note node sync data 2025-10-21 15:34:44 +08:00
hjlarry
7fecc7236c add more collaboration manager unit tests 2025-10-21 14:37:31 +08:00
hjlarry
9c7f6b7b71 add crdt provider unittests 2025-10-21 14:27:13 +08:00
hjlarry
b46da93e99 add unittests for event-emitter 2025-10-21 14:12:13 +08:00
hjlarry
e299a1fb20 add ws manager unit tests 2025-10-21 14:09:25 +08:00
hjlarry
122033cadb sort out code 2025-10-21 12:27:11 +08:00
hjlarry
df9bd1b3b5 add Parameters of ParametersExtractor node sync 2025-10-21 12:14:48 +08:00
hjlarry
f74492eb59 add prompt_template of LLM node sync 2025-10-21 12:00:42 +08:00
hjlarry
eaf1ae37dd add ENABLE_COLLABORATION_MODE 2025-10-21 11:46:28 +08:00
hjlarry
8e3b412ff6 fix websocket cookie auth 2025-10-21 11:46:00 +08:00
hjlarry
ba17f576e9 Merge remote-tracking branch 'myori/main' into feat/collaboration 2025-10-21 08:47:01 +08:00
lyzno1
9415ce4512 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-20 10:04:13 +08:00
lyzno1
239536933b Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-17 19:33:40 +08:00
hjlarry
80b34598e9 try to fix start node collaboration 2025-10-16 10:18:37 +08:00
lyzno1
9c66b92c34 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-15 21:08:08 +08:00
lyzno1
79872ea5e2 Refine workflow comment avatar highlight ring 2025-10-15 14:58:03 +08:00
lyzno1
cbf181bd76 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-15 11:06:23 +08:00
lyzno1
1393d21858 fix(web): adjust online users badge sizing and add pointer cursor to chevron 2025-10-15 11:06:04 +08:00
lyzno1
3a46b7bd18 fix(web): restyle workflow online-users avatar stack and dropdown 2025-10-15 10:48:38 +08:00
lyzno1
0bbfd81d26 fix: tooltip font 2025-10-15 10:35:42 +08:00
lyzno1
86db517142 fix(web): make workflow online-users dropdown click-based with revised spacing 2025-10-15 10:34:00 +08:00
lyzno1
50151f4007 fix(web): adjust workflow online-users icon and label styles 2025-10-15 10:21:54 +08:00
lyzno1
0395d1f91f Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-15 10:02:55 +08:00
lyzno1
5f4c1e4057 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-15 09:33:54 +08:00
hjlarry
d14413f3b0 comment click caculate the panel width 2025-10-15 09:11:44 +08:00
lyzno1
4fd968270c Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-14 18:56:27 +08:00
hjlarry
708a7dd362 fix comment mode can't drag node 2025-10-14 17:31:03 +08:00
hjlarry
cd85b75312 fix control panel hovered by comment icon 2025-10-14 17:16:33 +08:00
hjlarry
d685da377e fix minimap 2025-10-14 17:11:22 +08:00
hjlarry
8583992d23 when new user connected should rebroadcast the graph data 2025-10-14 16:57:02 +08:00
hjlarry
23fec75c90 cache the new created comment 2025-10-14 11:21:18 +08:00
hjlarry
ebe7303894 fix loop variable not sync well 2025-10-14 10:10:34 +08:00
hjlarry
79fb977f10 fix loop/iteration incorrect nodes width 2025-10-14 09:54:37 +08:00
lyzno1
c0af3414a3 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-14 07:54:05 +08:00
hjlarry
1857d37fae sync app published 2025-10-13 16:42:17 +08:00
hjlarry
60fdbb56a9 fix all lines missing 2025-10-13 16:38:50 +08:00
hjlarry
4c7853164d fix mcp server edit modal disappear 2025-10-13 16:36:39 +08:00
hjlarry
6c7a3ce4bb sync workflow publish to mcp server 2025-10-13 14:07:26 +08:00
lyzno1
a9e74b21f1 fix: increase ContentDialog z-index to display above workflow operators
The collaboration feature increased workflow operator z-index from z-10 to z-[60].
This caused the AppInfo ContentDialog (z-30) to appear below the operator buttons.
Increased ContentDialog z-index to z-[70] to ensure proper layer hierarchy.
2025-10-13 14:00:28 +08:00
lyzno1
e6730f7164 fix: dropdown menu border 2025-10-13 13:15:54 +08:00
lyzno1
3344723393 fix: prevent Enter key from triggering submit during IME composition
Add isComposing check at the start of handleKeyDown to ignore keyboard events during IME (Chinese/Japanese/Korean) input composition. This follows the existing pattern used in tag-management component and prevents premature form submission when users press Enter to confirm IME candidates.
2025-10-13 13:09:52 +08:00
lyzno1
c571185a91 fix: extract @mention highlighting from content in real-time to persist after edit 2025-10-13 13:03:55 +08:00
lyzno1
325c1cfa41 fix: prevent Save button flash by maintaining loading state until edit closes 2025-10-13 12:56:18 +08:00
lyzno1
1069421753 refactor: replace keyboard shortcut icons with custom EnterKey icon 2025-10-13 12:52:07 +08:00
lyzno1
b33a97ea5b style: update comment thread UI with design specs
- Fix edit bubble: keep avatar visible and match ThreadMessage layout
- Update edit container: rounded-xl, p-1, shadow-md, backdrop-blur
- Add keyboard shortcut icons (Cmd+Enter) to Save button
- Fix hover background: full-width with -mx-4 negative margin technique
- Apply design tokens consistently across components
2025-10-13 12:42:41 +08:00
lyzno1
d2c1d4c337 style: update mention dropdown UI to match design specs
- Update container: rounded-xl, border-0.5px, backdrop-blur, bg opacity 95%
- Update items: rounded-md with asymmetric padding (py-1 pl-2 pr-3)
- Use project design tokens (shadow-lg, bg-state-base-hover)
2025-10-13 12:24:28 +08:00
lyzno1
67762cf1d8 chore: resolve merge conflict in pnpm-lock.yaml
Merged origin/main into feat/collaboration and resolved dependency lock file conflicts by regenerating pnpm-lock.yaml through clean install.

Changes:
- Resolved eslint version differences (9.36.0 vs 9.35.0)
- Updated lock file reflects current dependency resolution
- All other changes from main branch successfully merged
2025-10-13 11:53:43 +08:00
hjlarry
eadce0287c app meta sync 2025-10-13 11:49:54 +08:00
hjlarry
ecaff5b63f fix loop var change cause collaboration crash 2025-10-13 10:06:50 +08:00
hjlarry
a300c9ef96 fix canvas empty on the bottom 2025-10-13 09:38:59 +08:00
lyzno1
44fe71e4db fix: ensure comment thread always scrolls to bottom on first render 2025-10-12 13:27:42 +08:00
lyzno1
0ac32188c5 feat: implement comprehensive focus management for comment thread
- Add forwardRef support to MentionInput to expose textarea ref
- Auto-focus reply input when thread opens (100ms delay)
- Restore focus after reply submission and edit operations
- Add Esc key handler to close thread with smart guards
- Enhance accessibility with ARIA attributes (dialog, modal, labelledby)
- Improve keyboard navigation and user experience

Implements P0-P3 priorities following WCAG 2.1 AA accessibility standards
2025-10-12 13:21:57 +08:00
lyzno1
9aaace706b feat: optimize comments panel filter UI and interaction logic 2025-10-12 13:04:24 +08:00
lyzno1
b22de5a824 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-12 13:04:07 +08:00
lyzno1
97463661c1 fix: translations 2025-10-11 20:33:55 +08:00
lyzno1
239a11855a fix: prevent dropdown from closing when showing inline delete confirmation
Use pre-rendering strategy with CSS visibility control instead of conditional rendering to avoid race condition between React state update and PortalToFollowElem's click-outside detection.
2025-10-11 20:21:52 +08:00
lyzno1
0632557d91 feat: use inline delete confirm for comment reply deletion(second time) 2025-10-11 18:37:41 +08:00
lyzno1
44be7d4c51 Revert "feat: use inline delete confirm for comment reply deletion"
This reverts commit a077a3f609.
2025-10-11 18:24:15 +08:00
lyzno1
efb4a9d327 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-11 18:18:40 +08:00
lyzno1
a077a3f609 feat: use inline delete confirm for comment reply deletion 2025-10-11 18:06:31 +08:00
lyzno1
3ccec0aab0 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-11 17:21:05 +08:00
hjlarry
3006133f0e sync node title 2025-10-11 15:48:51 +08:00
lyzno1
79beb25530 feat: add tooltips and improve delete button styling in CommentThread
- Add compact tooltips to Delete, Resolve, Previous, and Next buttons
- Change delete button hover to red background and text
- Use existing i18n translations for tooltip content
2025-10-11 15:22:37 +08:00
lyzno1
b47b228164 fix: align dropdown menu styles with design specs in CommentThread
- Update background to blur variant with backdrop filter
- Change border radius from lg to xl (12px)
- Add rounded corners to menu items to prevent hover overflow
2025-10-11 15:10:57 +08:00
lyzno1
be91db14d9 fix: add hover effect to first message in CommentThread
Wrap the root comment message with the same hover container as replies to ensure consistent hover behavior across all messages.
2025-10-11 15:08:27 +08:00
lyzno1
120893209e fix: align CommentPreview styles with design specs
- Update border radius to 24px with 3px bottom-left corner
- Change border width to 0.5px
- Add backdrop blur effect with bg-blur variant
- Replace custom shadow with standard shadow-lg
- Maintain proper Tailwind utility class usage
2025-10-11 15:02:06 +08:00
lyzno1
f19630bcf5 Merge remote-tracking branch 'origin/main' into feat/collaboration 2025-10-11 14:43:20 +08:00
lyzno1
9d93fda471 refactor: separate loading states for comment operations
Separate loading states to distinguish between different operations:
- activeCommentDetailLoading: loading comment details, delete/resolve operations
- replySubmitting: sending new replies
- replyUpdating: editing existing replies

Changes:
- Add replySubmitting and replyUpdating states to comment store
- Restore full-screen loading overlay for comment detail loading
- Use inline spinner (RiLoader2Line) in send/save buttons for reply operations
- Update loading state usage in handleCommentReply and handleCommentReplyUpdate
- Pass separated loading states from workflow index to CommentThread component

Benefits:
- UI clarity: different loading states have appropriate visual feedback
- Better UX: users can still navigate while sending replies
- Clear separation of concerns: each operation has its own loading state
2025-10-11 14:34:35 +08:00
lyzno1
d986659add chore: replace Chinese/Japanese comments with English translations 2025-10-11 14:20:37 +08:00
lyzno1
00dab7ca5f feat: improve mention input loading state and prevent button flash on submit 2025-10-11 14:20:37 +08:00
lyzno1
a4add403fb Fix MentionInput layout and improve comment hover styling 2025-10-11 14:20:37 +08:00
lyzno1
e9cdc96c74 feat: prevent duplicate @ insertion in mention input with visual feedback 2025-10-11 14:20:37 +08:00
lyzno1
6af1fea232 fix: update mention button icon color for better visibility in light mode 2025-10-11 14:20:37 +08:00
lyzno1
45d5d9e44f fix: mention input cannot scroll 2025-10-11 14:20:36 +08:00
lyzno1
376a084aca refactor: use PortalToFollowElem for dropdown with scroll handling
- Replace inline dropdown with PortalToFollowElem to prevent container overflow
- Use z-[100] for dropdown to ensure proper stacking
- Remove redundant outside click handler (handled by PortalToFollowElem)
- Add scroll event listener to auto-close dropdown when scrolling
- Dropdown now renders via portal outside message container
2025-10-11 14:20:36 +08:00
lyzno1
d1f42d47fe fix: improve dropdown menu hover and positioning 2025-10-11 14:20:36 +08:00
lyzno1
64b8fd87ad fix: improve dropdown menu positioning and z-index 2025-10-11 14:20:36 +08:00
lyzno1
364be48248 feat: add smooth scroll to comment thread 2025-10-11 14:20:36 +08:00
hjlarry
2bce046278 fix node error default value not sync 2025-10-11 14:17:58 +08:00
hjlarry
1120d552b6 fix knowledge node add/delete dataset not sync 2025-10-11 14:09:37 +08:00
hjlarry
69cab0817f fix comment input hoverd by comment content 2025-10-11 10:41:28 +08:00
hjlarry
c4d03bf378 change event type name of websocket 2025-10-11 09:07:02 +08:00
hjlarry
6c039be2ca fix jump to other page not disconnect websocket 2025-10-10 16:51:57 +08:00
hjlarry
832dabc8a4 only author can move the comment position 2025-10-10 15:58:01 +08:00
hjlarry
1da2028d9d keep the previous private property when import node data 2025-10-10 13:26:55 +08:00
hjlarry
7c3f6dcc8d use cloneDeep instead of json.parse 2025-10-10 10:34:00 +08:00
hjlarry
1472884eb5 sync the create/delete app in the list page 2025-10-10 10:18:23 +08:00
hjlarry
ec22b1c706 fix user uploaded avatar display incorrect 2025-10-09 17:40:20 +08:00
hjlarry
a1712df7c2 comment author avatar is the first avatar 2025-10-09 17:12:37 +08:00
hjlarry
a40e11cb3e only can edit own replies 2025-10-09 17:02:39 +08:00
hjlarry
61c46bea40 fix missing i18n 2025-10-09 16:55:53 +08:00
hjlarry
1c5c28a82c fix switch to cursor mode comment input still exists 2025-10-09 16:36:20 +08:00
hjlarry
2310145937 comment reply auto scoll down to bottom 2025-10-09 15:50:23 +08:00
hjlarry
6a9c9cadd0 fix comment hover the variable panel 2025-10-09 15:44:56 +08:00
hjlarry
7774ff9944 fix version not display 2025-10-09 15:07:36 +08:00
hjlarry
33d4c95470 can update comment position 2025-10-05 10:17:04 +08:00
hjlarry
659cbc05a9 fix mention-input in the bottom of the browser 2025-10-04 21:24:27 +08:00
hjlarry
6ce65de2cd fix merged main issues 2025-10-04 21:11:59 +08:00
hjlarry
93b2eb3ff6 Merge remote-tracking branch 'myori/main' into p284 2025-10-04 15:28:29 +08:00
hjlarry
bf71300635 improve comment cursor move 2025-10-04 14:36:10 +08:00
hjlarry
37ecd4a0bc fix @ input problem 2025-10-04 13:39:00 +08:00
hjlarry
827a1b181b fix comment icon position 2025-10-04 13:25:59 +08:00
hjlarry
c4e7cb75cd cache the mentioned users 2025-10-04 11:22:02 +08:00
hjlarry
98e4bfcda8 click comment icon not switch to comment mode 2025-10-03 23:36:56 +08:00
hjlarry
ee48ca7671 fix default comment icon 2025-09-30 15:23:43 +08:00
hjlarry
4ba6de1116 add leader session more check 2025-09-29 14:01:42 +08:00
hjlarry
bfbe636555 fix docker file websocket mode 2025-09-29 13:35:10 +08:00
hjlarry
54ae43ef47 sync children node data 2025-09-26 14:07:34 +08:00
hjlarry
7a74b5ee3e fix add child node resize parent node size 2025-09-26 14:04:50 +08:00
hjlarry
0e9d43d605 http node data sync 2025-09-26 11:13:20 +08:00
hjlarry
cc54363c27 sync the prompt editor 2025-09-26 10:48:00 +08:00
hjlarry
89affe3139 fix opened panel be affected 2025-09-26 09:20:33 +08:00
hjlarry
2c4977dbb1 fix bug 2025-09-25 16:56:06 +08:00
hjlarry
e240175116 sync nodes 2025-09-25 16:31:46 +08:00
hjlarry
2398ed6fe8 fix update env api update time error 2025-09-25 16:28:33 +08:00
hjlarry
a8420ac33c add fragment to prevent list missing key 2025-09-25 09:52:08 +08:00
hjlarry
8470be6411 improve delete comment i18n 2025-09-25 09:41:59 +08:00
hjlarry
3d6295c622 refactor delete comment and reply 2025-09-25 09:35:46 +08:00
17hz
ff2f7206f3 bump nextjs to 15.5 and turbopack for development mode (#24346)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
2025-09-25 09:10:09 +08:00
hjlarry
b937fc8978 app online user list 2025-09-24 17:03:33 +08:00
hjlarry
86a9a51952 add comment preview 2025-09-24 12:54:54 +08:00
hjlarry
4188c9a1dd fix dark theme 2025-09-24 10:08:33 +08:00
hjlarry
8c00f89e36 add icon to zoom2fit 2025-09-23 22:22:28 +08:00
hjlarry
9e8ac5c96b refactor cursor and add hide comment 2025-09-23 22:13:02 +08:00
hjlarry
05a67f4716 add display/hide collaborator cursors 2025-09-23 17:37:40 +08:00
hjlarry
f49476a206 add show/hide minimap 2025-09-23 17:20:41 +08:00
hjlarry
c1e9c56e25 fix style 2025-09-23 17:19:36 +08:00
hjlarry
d5dd73cacf add i18n for comment 2025-09-23 16:19:04 +08:00
hjlarry
21f7a49b4e fix restore page crash 2025-09-23 15:44:57 +08:00
hjlarry
716ac04e13 add comment shortcut 2025-09-23 15:40:53 +08:00
hjlarry
c28a32fc47 fix handleModeComment 2025-09-23 15:35:28 +08:00
hjlarry
31cba28e8a improve comment cursor icon 2025-09-23 15:28:22 +08:00
hjlarry
48cd7e6481 input comment should not cancel comment mode 2025-09-23 14:48:31 +08:00
hjlarry
47aba1c9f9 fix style 2025-09-23 14:41:34 +08:00
hjlarry
0f3f8bc0d9 make mention input can display name different color 2025-09-23 11:38:38 +08:00
hjlarry
e0df12c212 fix mentioned names color 2025-09-23 11:24:17 +08:00
hjlarry
eb448d9bb8 fix avatar background color 2025-09-23 11:09:02 +08:00
hjlarry
0ba77f13db fix avatar inset 2025-09-23 10:46:18 +08:00
hjlarry
f0a2eb843c fix user cursor should not over the panel 2025-09-23 10:35:16 +08:00
hjlarry
5cf3d9e4d9 fix nginx config 2025-09-22 14:21:07 +08:00
hjlarry
41958f55cd fix CSP 2025-09-22 14:20:11 +08:00
hjlarry
600ad232e1 fix config 2025-09-22 14:20:11 +08:00
hjlarry
7a3825cfce fix docker config 2025-09-22 14:20:11 +08:00
hjlarry
9519653422 change default ws url 2025-09-22 14:20:11 +08:00
hjlarry
efa2307c73 change default ws url 2025-09-22 14:20:11 +08:00
hjlarry
068fa3d0e3 fix CI 2025-09-22 14:20:11 +08:00
hjlarry
13d8dbd542 fix CI 2025-09-22 14:20:08 +08:00
hjlarry
b442ba8b2b fix UserAvatarList background color 2025-09-19 12:07:07 +08:00
hjlarry
10e36d2355 add avatar on canvas node 2025-09-19 10:43:28 +08:00
hjlarry
13c53fedad add avatar display on node 2025-09-19 10:07:01 +08:00
hjlarry
4bda1bd884 open node panel not affect others 2025-09-18 17:42:02 +08:00
hjlarry
3abe7850d6 fix migration file 2025-09-18 16:30:40 +08:00
hjlarry
b50284d864 fix merge problem 2025-09-18 15:45:53 +08:00
hjlarry
81c6e52401 Merge remote-tracking branch 'origin/p254' into p284 2025-09-18 15:14:55 +08:00
hjlarry
847d257366 Merge branch 'p254' into p284 2025-09-18 14:50:59 +08:00
hjlarry
687662cf1f comment sync 2025-09-18 13:27:27 +08:00
hjlarry
6432d98469 improve the icon display on canvas 2025-09-18 11:49:43 +08:00
hjlarry
088ccf8b8d add UserAvatarList component 2025-09-18 09:47:07 +08:00
hjlarry
e8683bf957 fix comment cursor position 2025-09-18 09:17:45 +08:00
hjlarry
4653981b6b not display more icon when in edit mode 2025-09-17 20:45:54 +08:00
hjlarry
e2547413d3 fix edit input mouse pos 2025-09-17 20:40:59 +08:00
hjlarry
ea17f41b5b refactor reply code 2025-09-17 20:29:23 +08:00
hjlarry
29178d8adf can edit and delete a reply 2025-09-17 17:44:09 +08:00
hjlarry
7e86ead574 upgrade style 2025-09-17 16:41:10 +08:00
hjlarry
72debcb228 refactor mention input 2025-09-17 16:28:47 +08:00
hjlarry
72737dabc7 fix at can't click bug 2025-09-17 14:50:05 +08:00
hjlarry
f6e5cb4381 improve comment detail 2025-09-17 14:34:36 +08:00
hjlarry
ffad3b5fb1 comment detail window fix height 2025-09-17 13:45:56 +08:00
hjlarry
cba9fc3020 add comment reply 2025-09-17 12:50:42 +08:00
hjlarry
e776accaf3 add top operation buttons of comment detail 2025-09-17 10:45:15 +08:00
hjlarry
3eac26929a sync the comment panel and canvas 2025-09-17 09:13:31 +08:00
hjlarry
4d3adec738 click canvas icon display the active comment detail 2025-09-17 09:01:16 +08:00
hjlarry
89bed479e4 improve comment panel 2025-09-16 17:25:51 +08:00
hjlarry
fdd673a3a9 improve comments panel 2025-09-16 13:39:31 +08:00
hjlarry
22f6d285c7 fix comment cursor in panel incorrect 2025-09-16 10:20:12 +08:00
hjlarry
10aa16b471 add workflow comment panel 2025-09-16 09:51:12 +08:00
hjlarry
b3838581fd improve mention 2025-09-15 17:13:46 +08:00
hjlarry
affbe7ccdb can mention user in the create comment 2025-09-15 16:42:31 +08:00
hjlarry
dd8577f832 comments display on canvas 2025-09-15 14:16:06 +08:00
hjlarry
d7f5da5df4 display comments avatar on the canvas 2025-09-15 11:41:06 +08:00
hjlarry
9fda130b3a fix click comment once more then esc not work 2025-09-15 11:11:07 +08:00
hjlarry
72cdbdba0f fix chat input style 2025-09-15 09:20:06 +08:00
hjlarry
b92a153902 refactor code 2025-09-14 13:03:08 +08:00
hjlarry
9f2927979b fix comment cursor icon 2025-09-14 12:50:18 +08:00
hjlarry
75257232c3 add create comment frontend 2025-09-14 12:10:37 +08:00
hjlarry
1721314c62 add frontend comment service 2025-09-13 17:57:19 +08:00
hjlarry
fc230bcc59 add force update workflow to support restore 2025-09-12 16:27:12 +08:00
hjlarry
b4636ddf44 add leader restore workflow 2025-09-12 15:34:41 +08:00
hjlarry
b1140301a4 sync import dsl 2025-09-12 14:46:40 +08:00
hjlarry
58cd785da6 use const for cursor move config 2025-09-11 09:36:22 +08:00
hjlarry
2035186cd2 click avatar to follow user cursor position 2025-09-11 09:26:05 +08:00
hjlarry
53ba6aadff cursor pos transform to canvas 2025-09-11 09:07:03 +08:00
hjlarry
f091868b7c use new get avatar api 2025-09-10 15:15:43 +08:00
hjlarry
89bedae0d3 remove the test code for develop collaboration 2025-09-10 14:27:20 +08:00
hjlarry
c8acc48976 ruff format 2025-09-10 14:25:37 +08:00
hjlarry
21fee59b22 use new features update api 2025-09-10 14:24:38 +08:00
hjlarry
957a8253f8 change user list to conversation var panel left 2025-09-10 09:26:38 +08:00
hjlarry
d5fc3e7bed add new conversation vars update api 2025-09-10 09:24:22 +08:00
hjlarry
ab438b42da use new env variables update api 2025-09-10 09:07:55 +08:00
hjlarry
3867fece4a mcp server update 2025-09-09 15:01:38 +08:00
hjlarry
2b908d4fbe add app state update 2025-09-09 14:24:37 +08:00
hjlarry
8ff062ec8b change user default color 2025-09-09 10:20:02 +08:00
hjlarry
294fc41aec add redo undo manager of CRDT 2025-09-09 09:58:55 +08:00
hjlarry
684f7df158 node data use crdt data 2025-09-08 14:46:28 +08:00
hjlarry
c3287755e3 add request leader to sync graph 2025-09-08 09:00:20 +08:00
hjlarry
9f97f4d79e fix cursor style 2025-09-06 15:54:19 +08:00
hjlarry
34eb421649 add currentUserId is me 2025-09-06 12:27:54 +08:00
hjlarry
850b05573e add dropdown users list 2025-09-06 12:01:49 +08:00
hjlarry
6ec8bfdfee add mouse over avatar display username 2025-09-06 11:29:45 +08:00
hjlarry
81638c248e use one getUserColor func 2025-09-06 11:22:59 +08:00
hjlarry
2e11b1298e add online users avatar 2025-09-06 11:19:47 +08:00
hjlarry
20320f3a27 show online users on the canvas 2025-09-06 00:08:17 +08:00
hjlarry
4019c12d26 fix missing import 2025-09-05 22:20:07 +08:00
hjlarry
cf72184ce4 each browser tab session a ws connected obj 2025-09-05 22:19:16 +08:00
hjlarry
ca8d15bc64 add mention user list api 2025-08-31 13:42:59 +08:00
hjlarry
a91c897fd3 improve code 2025-08-31 00:43:34 +08:00
hjlarry
816bdf0320 add delete comment and reply 2025-08-31 00:28:01 +08:00
hjlarry
d4a6acbd99 add update reply 2025-08-30 23:49:27 +08:00
hjlarry
e421db4005 add resolve comment 2025-08-30 22:37:01 +08:00
hjlarry
9067c2a9c1 add update comment 2025-08-22 17:48:14 +08:00
hjlarry
9f7321ca1a add create reply 2025-08-22 17:33:47 +08:00
hjlarry
5fa01132b9 add create and list comment api 2025-08-22 16:47:08 +08:00
hjlarry
e082b6d599 add workflow comment models 2025-08-22 11:28:26 +08:00
hjlarry
d44be2d835 add leader submit graph data 2025-08-21 17:53:39 +08:00
hjlarry
7dc8557033 add Leader election 2025-08-21 16:17:16 +08:00
hjlarry
72037a1865 improve cursors logic 2025-08-21 14:27:41 +08:00
hjlarry
2d1621c43d add leader but not review 2025-08-08 14:54:18 +08:00
hjlarry
d1a5db3310 rm useCollaborativeCursors compoent 2025-08-07 18:03:12 +08:00
hjlarry
ad8fd8fecc clone the node to avoid loro recursive 2025-08-07 17:45:38 +08:00
hjlarry
be74b76079 refactor websocket init 2025-08-07 17:31:12 +08:00
hjlarry
dd64af728f refactor the cursors component 2025-08-07 14:29:23 +08:00
hjlarry
e43b46786d refactor all the frontend code 2025-08-07 10:58:53 +08:00
hjlarry
3f3b37b843 refactor to support mutli websocket connections 2025-08-06 17:05:39 +08:00
hjlarry
2ecf9f6ddf add features collaboration 2025-08-06 10:58:32 +08:00
hjlarry
48c069fe68 support env vars collaborate 2025-08-05 15:22:22 +08:00
hjlarry
9c5c597c85 support empty collaboration event data 2025-08-05 15:21:41 +08:00
hjlarry
c2eec8545d collaborate conversation vars 2025-08-05 14:24:51 +08:00
hjlarry
2395d4be26 fix imported updates also broadcast to other clients 2025-08-05 10:21:22 +08:00
hjlarry
9455476705 handle edge delete 2025-08-04 14:17:59 +08:00
hjlarry
494e223706 some operations don't need to broadcast 2025-08-03 14:18:48 +08:00
hjlarry
348fd18230 refactor collaboration 2025-08-03 13:34:07 +08:00
hjlarry
7233b4de55 the initial data to collaboration store 2025-07-31 16:27:01 +08:00
hjlarry
af6df05685 add setNodes and setEdges of collaboration store 2025-07-31 15:25:50 +08:00
hjlarry
965b65db6e use loro for crdt data 2025-07-31 14:02:53 +08:00
hjlarry
4cc01c8aa8 try a lot for yjs, but update data still not work... 2025-07-30 14:36:29 +08:00
hjlarry
41372168b6 refactor code 2025-07-23 10:04:16 +08:00
hjlarry
f4438b0a08 support mouse display 2025-07-22 18:08:35 +08:00
hjlarry
897c842637 ruff format 2025-07-21 16:13:04 +08:00
hjlarry
ee86ceb906 fix gunicorn gvent 2025-07-21 16:09:51 +08:00
hjlarry
e298732499 refactor code 2025-07-21 16:07:22 +08:00
hjlarry
4081937e22 migrate to python-socketio 2025-07-21 14:57:28 +08:00
hjlarry
f9aedb2118 add collaborate event 2025-07-21 11:10:23 +08:00
hjlarry
74b4719af8 support broadcast online users 2025-07-18 15:02:34 +08:00
hjlarry
2f35cc9188 add online users backend api and frontend submit cursor pos 2025-07-18 11:17:08 +08:00
hjlarry
2f966d8c38 fix websocket auth 2025-07-17 17:16:52 +08:00
hjlarry
b0868d9136 fix websocket auth 2025-07-17 17:16:38 +08:00
hjlarry
37440e9416 ruff format 2025-07-17 15:37:13 +08:00
hjlarry
0d7d27ec0b establish websocket connection 2025-07-17 15:36:50 +08:00
373 changed files with 17295 additions and 19606 deletions

1
.gitignore vendored
View File

@@ -209,7 +209,6 @@ api/.vscode
.history
.idea/
web/migration/
# pnpm
/.pnpm-store

View File

@@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
# Collaboration mode toggle
ENABLE_COLLABORATION_MODE=false
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

View File

@@ -0,0 +1,52 @@
## Purpose
`api/controllers/console/datasets/datasets_document.py` contains the console (authenticated) APIs for managing dataset documents (list/create/update/delete, processing controls, estimates, etc.).
## Storage model (uploaded files)
- For local file uploads into a knowledge base, the binary is stored via `extensions.ext_storage.storage` under the key:
- `upload_files/<tenant_id>/<uuid>.<ext>`
- File metadata is stored in the `upload_files` table (`UploadFile` model), keyed by `UploadFile.id`.
- Dataset `Document` records reference the uploaded file via:
- `Document.data_source_info.upload_file_id`
## Download endpoint
- `GET /datasets/<dataset_id>/documents/<document_id>/download`
- Only supported when `Document.data_source_type == "upload_file"`.
- Performs dataset permission + tenant checks via `DocumentResource.get_document(...)`.
- Delegates `Document -> UploadFile` validation and signed URL generation to `DocumentService.get_document_download_url(...)`.
- Applies `cloud_edition_billing_rate_limit_check("knowledge")` to match other KB operations.
- Response body is **only**: `{ "url": "<signed-url>" }`.
- `POST /datasets/<dataset_id>/documents/download-zip`
- Accepts `{ "document_ids": ["..."] }` (upload-file only).
- Returns `application/zip` as a single attachment download.
- Rationale: browsers often block multiple automatic downloads; a ZIP avoids that limitation.
- Applies `cloud_edition_billing_rate_limit_check("knowledge")`.
- Delegates dataset permission checks, document/upload-file validation, and download-name generation to
`DocumentService.prepare_document_batch_download_zip(...)` before streaming the ZIP.
## Verification plan
- Upload a document from a local file into a dataset.
- Call the download endpoint and confirm it returns a signed URL.
- Open the URL and confirm:
- Response headers force download (`Content-Disposition`), and
- Downloaded bytes match the uploaded file.
- Select multiple uploaded-file documents and download as ZIP; confirm all selected files exist in the archive.
## Shared helper
- `DocumentService.get_document_download_url(document)` resolves the `UploadFile` and signs a download URL.
- `DocumentService.prepare_document_batch_download_zip(...)` performs dataset permission checks, batches
document + upload file lookups, preserves request order, and generates the client-visible ZIP filename.
- Internal helpers now live in `DocumentService` (`_get_upload_file_id_for_upload_file_document(...)`,
`_get_upload_file_for_upload_file_document(...)`, `_get_upload_files_by_document_id_for_zip_download(...)`).
- ZIP packing is handled by `FileService.build_upload_files_zip_tempfile(...)`, which also:
- sanitizes entry names to avoid path traversal, and
- deduplicates names while preserving extensions (e.g., `doc.txt``doc (1).txt`).
Streaming the response and deferring cleanup is handled by the route via `send_file(path, ...)` + `ExitStack` +
`response.call_on_close(...)` (the file is deleted when the response is closed).

View File

@@ -0,0 +1,18 @@
## Purpose
`api/services/dataset_service.py` hosts dataset/document service logic used by console and API controllers.
## Batch document operations
- Batch document workflows should avoid N+1 database queries by using set-based lookups.
- Tenant checks must be enforced consistently across dataset/document operations.
- `DocumentService.get_documents_by_ids(...)` fetches documents for a dataset using `id.in_(...)`.
- `FileService.get_upload_files_by_ids(...)` performs tenant-scoped batch lookup for `UploadFile` (dedupes ids with `set(...)`).
- `DocumentService.get_document_download_url(...)` and `prepare_document_batch_download_zip(...)` handle
dataset/document permission checks plus `Document -> UploadFile` validation for download endpoints.
## Verification plan
- Exercise document list and download endpoints that use the service helpers.
- Confirm batch download uses constant query count for documents + upload files.
- Request a ZIP with a missing document id and confirm a 404 is returned.

View File

@@ -0,0 +1,35 @@
## Purpose
`api/services/file_service.py` owns business logic around `UploadFile` objects: upload validation, storage persistence,
previews/generators, and deletion.
## Key invariants
- All storage I/O goes through `extensions.ext_storage.storage`.
- Uploaded file keys follow: `upload_files/<tenant_id>/<uuid>.<ext>`.
- Upload validation is enforced in `FileService.upload_file(...)` (blocked extensions, size limits, dataset-only types).
## Batch lookup helpers
- `FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)` is the canonical tenant-scoped batch loader for
`UploadFile`.
## Dataset document download helpers
The dataset document download/ZIP endpoints now delegate “Document → UploadFile” validation and permission checks to
`DocumentService` (`api/services/dataset_service.py`). `FileService` stays focused on generic `UploadFile` operations
(uploading, previews, deletion), plus generic ZIP serving.
### ZIP serving
- `FileService.build_upload_files_zip_tempfile(...)` builds a ZIP from `UploadFile` objects and yields a seeked
tempfile **path** so callers can stream it (e.g., `send_file(path, ...)`) without hitting "read of closed file"
issues from file-handle lifecycle during streamed responses.
- Flask `send_file(...)` and the `ExitStack`/`call_on_close(...)` cleanup pattern are handled in the route layer.
## Verification plan
- Unit: `api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py`
- Verify signed URL generation for upload-file documents and ZIP download behavior for multiple documents.
- Unit: `api/tests/unit_tests/services/test_file_service_zip_and_lookup.py`
- Verify ZIP packing produces a valid, openable archive and preserves file content.

View File

@@ -0,0 +1,28 @@
## Purpose
Unit tests for the console dataset document download endpoint:
- `GET /datasets/<dataset_id>/documents/<document_id>/download`
## Testing approach
- Uses `Flask.test_request_context()` and calls the `Resource.get(...)` method directly.
- Monkeypatches console decorators (`login_required`, `setup_required`, rate limit) to no-ops to keep the test focused.
- Mocks:
- `DatasetService.get_dataset` / `check_dataset_permission`
- `DocumentService.get_document` for single-file download tests
- `DocumentService.get_documents_by_ids` + `FileService.get_upload_files_by_ids` for ZIP download tests
- `FileService.get_upload_files_by_ids` for `UploadFile` lookups in single-file tests
- `services.dataset_service.file_helpers.get_signed_file_url` to return a deterministic URL
- Document mocks include `id` fields so batch lookups can map documents by id.
## Covered cases
- Success returns `{ "url": "<signed>" }` for upload-file documents.
- 404 when document is not `upload_file`.
- 404 when `upload_file_id` is missing.
- 404 when referenced `UploadFile` row does not exist.
- 403 when document tenant does not match current tenant.
- Batch ZIP download returns `application/zip` for upload-file documents.
- Batch ZIP download rejects non-upload-file documents.
- Batch ZIP download uses a random `.zip` attachment name (`download_name`), so tests only assert the suffix.

View File

@@ -0,0 +1,18 @@
## Purpose
Unit tests for `api/services/file_service.py` helper methods that are not covered by higher-level controller tests.
## Whats covered
- `FileService.build_upload_files_zip_tempfile(...)`
- ZIP entry name sanitization (no directory components / traversal)
- name deduplication while preserving extensions
- writing streamed bytes from `storage.load(...)` into ZIP entries
- yields a tempfile path so callers can open/stream the ZIP without holding a live file handle
- `FileService.get_upload_files_by_ids(...)`
- returns `{}` for empty id lists
- returns an id-keyed mapping for non-empty lists
## Notes
- These tests intentionally stub `storage.load` and `db.session.scalars(...).all()` to avoid needing a real DB/storage.

View File

@@ -1,3 +1,4 @@
import os
import sys
@@ -8,10 +9,15 @@ def is_db_command() -> bool:
# create app
flask_app = None
socketio_app = None
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
socketio_app = app
flask_app = app
else:
# Gunicorn and Celery handle monkey patching automatically in production by
# specifying the `gevent` worker class. Manual monkey patching is not required here.
@@ -22,8 +28,15 @@ else:
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]
socketio_app, flask_app = create_app()
app = flask_app
celery = flask_app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", 5001))
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()

View File

@@ -1,6 +1,7 @@
import logging
import time
import socketio # type: ignore[reportMissingTypeStubs]
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
@@ -8,6 +9,7 @@ from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
from extensions.ext_socketio import sio
logger = logging.getLogger(__name__)
@@ -60,14 +62,18 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> DifyApp:
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
return socketio_app, app
def initialize_extensions(app: DifyApp):

View File

@@ -1219,6 +1219,13 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class CollaborationConfig(BaseSettings):
ENABLE_COLLABORATION_MODE: bool = Field(
description="Whether to enable collaboration mode features across the workspace",
default=False,
)
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
@@ -1333,6 +1340,7 @@ class FeatureConfig(
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
CollaborationConfig,
LoginConfig,
AccountConfig,
SwaggerUIConfig,

View File

@@ -63,6 +63,7 @@ from .app import (
statistic,
workflow,
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_run,
workflow_statistic,
@@ -112,6 +113,7 @@ from .explore import (
recommended_app,
saved_message,
)
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
# Import tag controllers
from .tag import tags
@@ -203,6 +205,7 @@ __all__ = [
"website",
"workflow",
"workflow_app_log",
"workflow_comment",
"workflow_draft_variable",
"workflow_run",
"workflow_statistic",

View File

@@ -1,4 +1,3 @@
import re
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -68,48 +67,6 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
r"<script[^>]*>.*?</script>", # Script tags
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
r"javascript:", # JavaScript protocol
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
r"<embed[^>]*>", # Embed tags (self-closing)
r"<link[^>]*>", # Link tags with javascript
]
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
"""
Validate that a string value doesn't contain potential XSS payloads.
Args:
value: The string value to validate
field_name: Name of the field for error messages
Returns:
The original value if safe
Raises:
ValueError: If the value contains XSS patterns
"""
if value is None:
return None
value_lower = value.lower()
for pattern in _XSS_PATTERNS:
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
raise ValueError(
f"{field_name} contains invalid characters or patterns. "
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
)
return value
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@@ -118,11 +75,6 @@ class CreateAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
@@ -133,11 +85,6 @@ class UpdateAppPayload(BaseModel):
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
@@ -146,11 +93,6 @@ class CopyAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")

View File

@@ -55,35 +55,6 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
class ContextGeneratePayload(BaseModel):
"""Payload for generating extractor code node."""
workflow_id: str = Field(..., description="Workflow ID")
node_id: str = Field(..., description="Current tool/llm node ID")
parameter_name: str = Field(..., description="Parameter name to generate code for")
language: str = Field(default="python3", description="Code language (python3/javascript)")
prompt_messages: list[dict[str, Any]] = Field(
..., description="Multi-turn conversation history, last message is the current instruction"
)
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class SuggestedQuestionsPayload(BaseModel):
"""Payload for generating suggested questions."""
workflow_id: str = Field(..., description="Workflow ID")
node_id: str = Field(..., description="Current tool/llm node ID")
parameter_name: str = Field(..., description="Parameter name")
language: str = Field(
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
)
model_config_data: dict[str, Any] | None = Field(
default=None,
alias="model_config",
description="Model configuration (optional, uses system default if not provided)",
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -93,8 +64,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ContextGeneratePayload)
reg(SuggestedQuestionsPayload)
@console_ns.route("/rule-generate")
@@ -309,74 +278,3 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
@console_ns.route("/context-generate")
class ContextGenerateApi(Resource):
@console_ns.doc("generate_with_context")
@console_ns.doc(description="Generate with multi-turn conversation context")
@console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
@console_ns.response(200, "Content generated successfully")
@console_ns.response(400, "Invalid request parameters or workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
from core.llm_generator.utils import deserialize_prompt_messages
args = ContextGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
prompt_messages = deserialize_prompt_messages(args.prompt_messages)
try:
return LLMGenerator.generate_with_context(
tenant_id=current_tenant_id,
workflow_id=args.workflow_id,
node_id=args.node_id,
parameter_name=args.parameter_name,
language=args.language,
prompt_messages=prompt_messages,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
@console_ns.route("/context-generate/suggested-questions")
class SuggestedQuestionsApi(Resource):
@console_ns.doc("generate_suggested_questions")
@console_ns.doc(description="Generate suggested questions for context generation")
@console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
@console_ns.response(200, "Questions generated successfully")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
return LLMGenerator.generate_suggested_questions(
tenant_id=current_tenant_id,
workflow_id=args.workflow_id,
node_id=args.node_id,
parameter_name=args.parameter_name,
language=args.language,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)

View File

@@ -32,8 +32,10 @@ from core.trigger.debug.event_selectors import (
from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@@ -43,11 +45,10 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from models.workflow import Workflow
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.workflow.entities import MentionGraphRequest, MentionParameterSchema
from services.workflow.mention_graph_service import MentionGraphService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@@ -182,6 +183,14 @@ class WorkflowUpdatePayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100)
class WorkflowFeaturesPayload(BaseModel):
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
class WorkflowOnlineUsersQuery(BaseModel):
workflow_ids: str = Field(..., description="Comma-separated workflow IDs")
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
@@ -190,15 +199,6 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
class MentionGraphPayload(BaseModel):
"""Request payload for generating mention graph."""
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
parameter_key: str = Field(description="Key of the parameter being extracted")
context_source: list[str] = Field(description="Variable selector for the context source")
parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -214,9 +214,10 @@ reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersQuery)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
reg(MentionGraphPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@@ -803,6 +804,31 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
class WorkflowFeaturesApi(Resource):
"""Update draft workflow features."""
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
@console_ns.doc("update_workflow_features")
@console_ns.doc(description="Update draft workflow features")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Workflow features updated successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
features = args.features
workflow_service = WorkflowService()
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@@ -1180,52 +1206,30 @@ class DraftWorkflowTriggerRunAllApi(Resource):
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/mention-graph")
class MentionGraphApi(Resource):
"""
API for generating Mention LLM node graph structures.
This endpoint creates a complete graph structure containing an LLM node
configured to extract values from list[PromptMessage] variables.
"""
@console_ns.doc("generate_mention_graph")
@console_ns.doc(description="Generate a Mention LLM node graph structure")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MentionGraphPayload.__name__])
@console_ns.response(200, "Mention graph generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@console_ns.route("/apps/workflows/online-users")
class WorkflowOnlineUsersApi(Resource):
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
@console_ns.doc("get_workflow_online_users")
@console_ns.doc(description="Get workflow online users")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Generate a Mention LLM node graph structure.
@marshal_with(online_user_list_fields)
def get(self):
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
Returns a complete graph structure containing a single LLM node
configured for extracting values from list[PromptMessage] context.
"""
workflow_ids = [workflow_id.strip() for workflow_id in args.workflow_ids.split(",") if workflow_id.strip()]
payload = MentionGraphPayload.model_validate(console_ns.payload or {})
results = []
for workflow_id in workflow_ids:
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}")
parameter_schema = MentionParameterSchema(
name=payload.parameter_schema.get("name", payload.parameter_key),
type=payload.parameter_schema.get("type", "string"),
description=payload.parameter_schema.get("description", ""),
)
users = []
for _, user_info_json in users_json.items():
try:
users.append(json.loads(user_info_json))
except Exception:
continue
results.append({"workflow_id": workflow_id, "users": users})
request = MentionGraphRequest(
parent_node_id=payload.parent_node_id,
parameter_key=payload.parameter_key,
context_source=payload.context_source,
parameter_schema=parameter_schema,
)
with Session(db.engine) as session:
service = MentionGraphService(session)
response = service.generate_mention_graph(tenant_id=app_model.tenant_id, request=request)
return response.model_dump()
return {"data": results}

View File

@@ -0,0 +1,317 @@
import logging
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import account_with_role_fields
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowCommentCreatePayload(BaseModel):
position_x: float = Field(..., description="Comment X position")
position_y: float = Field(..., description="Comment Y position")
content: str = Field(..., description="Comment content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentUpdatePayload(BaseModel):
content: str = Field(..., description="Comment content")
position_x: float | None = Field(default=None, description="Comment X position")
position_y: float | None = Field(default=None, description="Comment Y position")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyCreatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyUpdatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
for model in (
WorkflowCommentCreatePayload,
WorkflowCommentUpdatePayload,
WorkflowCommentReplyCreatePayload,
WorkflowCommentReplyUpdatePayload,
):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
workflow_comment_reply_create_model = console_ns.model(
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
)
workflow_comment_reply_update_model = console_ns.model(
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
)
workflow_comment_mention_users_model = console_ns.model(
"WorkflowCommentMentionUsers",
{"users": fields.List(fields.Nested(account_with_role_fields))},
)
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@console_ns.doc("list_workflow_comments")
@console_ns.doc(description="Get all comments for a workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_basic_model, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@console_ns.doc("create_workflow_comment")
@console_ns.doc(description="Create a new workflow comment")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_create_model)
def post(self, app_model: App):
"""Create a new workflow comment."""
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@console_ns.doc("get_workflow_comment")
@console_ns.doc(description="Get a specific workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_detail_model)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@console_ns.doc("update_workflow_comment")
@console_ns.doc(description="Update a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_update_model)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result
@console_ns.doc("delete_workflow_comment")
@console_ns.doc(description="Delete a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(204, "Comment deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@console_ns.doc("resolve_workflow_comment")
@console_ns.doc(description="Resolve a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_resolve_model)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@console_ns.doc("create_workflow_comment_reply")
@console_ns.doc(description="Add a reply to a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_create_model)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=payload.content,
created_by=current_user.id,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@console_ns.doc("update_workflow_comment_reply")
@console_ns.doc(description="Update a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_update_model)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
reply = WorkflowCommentService.update_reply(
reply_id=reply_id,
user_id=current_user.id,
content=payload.content,
mentioned_user_ids=payload.mentioned_user_ids,
)
return reply
@console_ns.doc("delete_workflow_comment_reply")
@console_ns.doc(description="Delete a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.response(204, "Reply deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@console_ns.doc("workflow_comment_mention_users")
@console_ns.doc(description="Get all users in current tenant for mentions")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_mention_users_model)
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"users": members}

View File

@@ -17,13 +17,13 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import login_required
from libs.login import current_user, login_required
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@@ -43,6 +43,16 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
value: Any | None = Field(default=None, description="Variable value")
class ConversationVariableUpdatePayload(BaseModel):
conversation_variables: list[dict[str, Any]] = Field(
..., description="Conversation variables for the draft workflow"
)
class EnvironmentVariableUpdatePayload(BaseModel):
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
@@ -51,6 +61,14 @@ console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ConversationVariableUpdatePayload.__name__,
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EnvironmentVariableUpdatePayload.__name__,
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment):
@@ -58,8 +76,6 @@ def _convert_values_to_json_serializable_object(value: Segment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, ArrayPromptMessageSegment):
return value.to_object()
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
@@ -385,7 +401,7 @@ class VariableApi(Resource):
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@@ -478,6 +494,34 @@ class ConversationVariableCollectionApi(Resource):
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
@console_ns.doc("update_conversation_variables")
@console_ns.doc(description="Update conversation variables for workflow draft")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Conversation variables updated successfully")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
def post(self, app_model: App):
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
conversation_variables_list = payload.conversation_variables
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow_service.update_draft_workflow_conversation_variables(
app_model=app_model,
account=current_user,
conversation_variables=conversation_variables,
)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@@ -529,3 +573,31 @@ class EnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
@console_ns.doc("update_environment_variables")
@console_ns.doc(description="Update environment variables for workflow draft")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Environment variables updated successfully")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
environment_variables_list = payload.environment_variables
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
workflow_service.update_draft_workflow_environment_variables(
app_model=app_model,
account=current_user,
environment_variables=environment_variables,
)
return {"result": "success"}

View File

@@ -2,10 +2,12 @@ import json
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast
from contextlib import ExitStack
from typing import Any, Literal, cast
from uuid import UUID
import sqlalchemy as sa
from flask import request
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
@@ -42,6 +44,7 @@ from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -65,6 +68,9 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# NOTE: Keep constants near the top of the module for discoverability.
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
@@ -104,6 +110,12 @@ class DocumentRenamePayload(BaseModel):
name: str
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
@@ -120,6 +132,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
DocumentBatchDownloadZipPayload,
)
@@ -853,6 +866,62 @@ class DocumentApi(DocumentResource):
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
class DocumentDownloadApi(DocumentResource):
"""Return a signed download URL for a dataset document's original uploaded file."""
@console_ns.doc("get_dataset_document_download_url")
@console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
return {"url": DocumentService.get_document_download_url(document)}
@console_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
class DocumentBatchDownloadZipApi(DocumentResource):
"""Download multiple uploaded-file documents as a single ZIP (avoids browser multi-download limits)."""
@console_ns.doc("download_dataset_documents_as_zip")
@console_ns.doc(description="Download selected dataset documents as a single ZIP archive (upload-file only)")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: str):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
dataset_id=dataset_id,
document_ids=document_ids,
tenant_id=current_tenant_id,
current_user=current_user,
)
# Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route.
with ExitStack() as stack:
zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
response = send_file(
zip_path,
mimetype="application/zip",
as_attachment=True,
download_name=download_name,
)
cleanup = stack.pop_all()
response.call_on_close(cleanup.close)
return response
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
class DocumentProcessingApi(DocumentResource):
@console_ns.doc("update_document_processing")

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,112 @@
import logging
from collections.abc import Callable
from typing import cast
from flask import Request as FlaskRequest
from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
from services.account_service import AccountService
from services.workflow_collaboration_service import WorkflowCollaborationService
repository = WorkflowCollaborationRepository()
collaboration_service = WorkflowCollaborationService(repository, sio)
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
@_sio_on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
try:
request_environ = FlaskRequest(environ)
token = extract_access_token(request_environ)
except Exception:
logging.exception("Failed to extract token")
token = None
if not token:
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
logging.warning(
"Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid
)
return False
if not user.has_edit_permission:
logging.warning(
"Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid
)
return False
collaboration_service.save_session(sid, user)
return True
except Exception:
logging.exception("Socket authentication failed")
return False
@_sio_on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
result = collaboration_service.register_session(workflow_id, sid)
if not result:
return {"msg": "unauthorized"}, 401
user_id, is_leader = result
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@_sio_on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
collaboration_service.disconnect_session(sid)
@_sio_on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouse_move
2. vars_and_features_update
3. sync_request (ask leader to update graph)
4. app_state_update
5. mcp_server_update
6. workflow_update
7. comments_update
8. node_panel_presence
"""
return collaboration_service.relay_collaboration_event(sid, data)
@_sio_on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
return collaboration_service.relay_graph_event(sid, data)

View File

@@ -36,6 +36,7 @@ from controllers.console.wraps import (
only_edition_cloud,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
@@ -73,6 +74,10 @@ class AccountAvatarPayload(BaseModel):
avatar: str
class AccountAvatarQuery(BaseModel):
avatar: str = Field(..., description="Avatar file ID")
class AccountInterfaceLanguagePayload(BaseModel):
interface_language: str
@@ -158,6 +163,7 @@ def reg(cls: type[BaseModel]):
reg(AccountInitPayload)
reg(AccountNamePayload)
reg(AccountAvatarPayload)
reg(AccountAvatarQuery)
reg(AccountInterfaceLanguagePayload)
reg(AccountInterfaceThemePayload)
reg(AccountTimezonePayload)
@@ -248,6 +254,18 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
@console_ns.doc("get_account_avatar")
@console_ns.doc(description="Get account avatar url")
@setup_required
@login_required
@account_initialization_required
def get(self):
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
avatar_url = file_helpers.get_signed_file_url(args.avatar)
return {"avatar_url": avatar_url}
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required

View File

@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -70,8 +70,6 @@ class _NodeSnapshot:
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
mention_parent_id: str = ""
"""Empty string means the node is not an extractor node."""
class WorkflowResponseConverter:
@@ -133,7 +131,6 @@ class WorkflowResponseConverter:
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
mention_parent_id=event.in_mention_parent_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
@@ -290,7 +287,6 @@ class WorkflowResponseConverter:
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy,
),
)
@@ -377,7 +373,6 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
),
)
@@ -427,7 +422,6 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
retry_index=event.retry_index,
),
)

View File

@@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
@@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk

View File

@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -385,7 +385,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@@ -406,7 +405,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
@@ -430,7 +428,6 @@ class WorkflowBasedAppRunner:
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
@@ -447,7 +444,6 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
@@ -464,7 +460,6 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
@@ -474,7 +469,6 @@ class WorkflowBasedAppRunner:
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -483,7 +477,6 @@ class WorkflowBasedAppRunner:
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunAgentLogEvent):

View File

@@ -190,8 +190,6 @@ class QueueTextChunkEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
class QueueAgentMessageEvent(AppQueueEvent):
@@ -231,8 +229,6 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
class QueueAnnotationReplyEvent(AppQueueEvent):
@@ -310,8 +306,6 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_run_index: int = 1 # FIXME(-LAN-): may not used
in_iteration_id: str | None = None
in_loop_id: str | None = None
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
agent_strategy: AgentNodeStrategyInit | None = None
@@ -334,8 +328,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -391,8 +383,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -417,8 +407,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict)

View File

@@ -262,7 +262,6 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
@@ -286,7 +285,6 @@ class NodeStartStreamResponse(StreamResponse):
"extras": {},
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
},
}
@@ -322,7 +320,6 @@ class NodeFinishStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@@ -352,7 +349,6 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [],
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
},
}
@@ -388,7 +384,6 @@ class NodeRetryStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
@@ -419,7 +414,6 @@ class NodeRetryStreamResponse(StreamResponse):
"files": [],
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
"retry_index": self.data.retry_index,
},
}

View File

@@ -1,5 +1,4 @@
import base64
import logging
from collections.abc import Mapping
from configs import dify_config
@@ -11,10 +10,7 @@ from core.model_runtime.entities import (
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import (
MultiModalPromptMessageContent,
PromptMessageContentUnionTypes,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
@@ -22,8 +18,6 @@ from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
def get_attr(*, file: File, attr: FileAttribute):
match attr:
@@ -95,8 +89,6 @@ def to_prompt_message_content(
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
# Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
"file_ref": _encode_file_ref(f),
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
@@ -104,17 +96,6 @@ def to_prompt_message_content(
return prompt_class_map[f.type].model_validate(params)
def _encode_file_ref(f: File) -> str | None:
"""Encode file reference as 'transfer_method:id_or_url' string."""
if f.transfer_method == FileTransferMethod.REMOTE_URL:
return f"remote:{f.remote_url}" if f.remote_url else None
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
return f"local:{f.related_id}" if f.related_id else None
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
return f"tool:{f.related_id}" if f.related_id else None
return None
def download(f: File, /):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
@@ -183,128 +164,3 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def restore_multimodal_content(
content: MultiModalPromptMessageContent,
) -> MultiModalPromptMessageContent:
"""
Restore base64_data or url for multimodal content from file_ref.
file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")
Args:
content: MultiModalPromptMessageContent with file_ref field
Returns:
MultiModalPromptMessageContent with restored base64_data or url
"""
# Skip if no file reference or content already has data
if not content.file_ref:
return content
if content.base64_data or content.url:
return content
try:
file = _build_file_from_ref(
file_ref=content.file_ref,
file_format=content.format,
mime_type=content.mime_type,
filename=content.filename,
)
if not file:
return content
# Restore content based on config
if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
restored_base64 = _get_encoded_string(file)
return content.model_copy(update={"base64_data": restored_base64})
else:
restored_url = _to_url(file)
return content.model_copy(update={"url": restored_url})
except Exception as e:
logger.warning("Failed to restore multimodal content: %s", e)
return content
def _build_file_from_ref(
file_ref: str,
file_format: str | None,
mime_type: str | None,
filename: str | None,
) -> File | None:
"""
Build a File object from encoded file_ref string.
Args:
file_ref: Encoded reference "transfer_method:id_or_url"
file_format: The file format/extension (without dot)
mime_type: The mime type
filename: The filename
Returns:
File object with storage_key loaded, or None if not found
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import UploadFile
from models.tools import ToolFile
# Parse file_ref: "method:value"
if ":" not in file_ref:
logger.warning("Invalid file_ref format: %s", file_ref)
return None
method, value = file_ref.split(":", 1)
extension = f".{file_format}" if file_format else None
if method == "remote":
return File(
tenant_id="",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=value,
extension=extension,
mime_type=mime_type,
filename=filename,
storage_key="",
)
# Query database for storage_key
with Session(db.engine) as session:
if method == "local":
stmt = select(UploadFile).where(UploadFile.id == value)
upload_file = session.scalar(stmt)
if upload_file:
return File(
tenant_id=upload_file.tenant_id,
type=FileType(upload_file.extension)
if hasattr(FileType, upload_file.extension.upper())
else FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id=value,
extension=extension or ("." + upload_file.extension if upload_file.extension else None),
mime_type=mime_type or upload_file.mime_type,
filename=filename or upload_file.name,
storage_key=upload_file.key,
)
elif method == "tool":
stmt = select(ToolFile).where(ToolFile.id == value)
tool_file = session.scalar(stmt)
if tool_file:
return File(
tenant_id=tool_file.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=value,
extension=extension,
mime_type=mime_type or tool_file.mimetype,
filename=filename or tool_file.name,
storage_key=tool_file.file_key,
)
logger.warning("File not found for file_ref: %s", file_ref)
return None

View File

@@ -1,16 +1,11 @@
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, cast
from collections.abc import Sequence
from typing import Protocol, cast
import json_repair
from core.llm_generator.output_models import (
CodeNodeStructuredOutput,
InstructionModifyOutput,
SuggestedQuestionsOutput,
)
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -398,432 +393,6 @@ class LLMGenerator:
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_with_context(
cls,
tenant_id: str,
workflow_id: str,
node_id: str,
parameter_name: str,
language: str,
prompt_messages: list[PromptMessage],
model_config: dict,
) -> dict:
"""
Generate extractor code node based on conversation context.
Args:
tenant_id: Tenant/workspace ID
workflow_id: Workflow ID
node_id: Current tool/llm node ID
parameter_name: Parameter name to generate code for
language: Code language (python3/javascript)
prompt_messages: Multi-turn conversation history (last message is instruction)
model_config: Model configuration (provider, name, completion_params)
Returns:
dict with CodeNodeData format:
- variables: Input variable selectors
- code_language: Code language
- code: Generated code
- outputs: Output definitions
- message: Explanation
- error: Error message if any
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from services.workflow_service import WorkflowService
# Get workflow
with Session(db.engine) as session:
stmt = select(App).where(App.id == workflow_id)
app = session.scalar(stmt)
if not app:
return cls._error_response(f"App {workflow_id} not found")
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return cls._error_response(f"Workflow for app {workflow_id} not found")
# Get upstream nodes via edge backtracking
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
# Get current node info
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
if not current_node:
return cls._error_response(f"Node {node_id} not found")
# Get parameter info
parameter_info = cls._get_parameter_info(
tenant_id=tenant_id,
node_data=current_node.get("data", {}),
parameter_name=parameter_name,
)
# Build system prompt
system_prompt = cls._build_extractor_system_prompt(
upstream_nodes=upstream_nodes,
current_node=current_node,
parameter_info=parameter_info,
language=language,
)
# Construct complete prompt_messages with system prompt
complete_messages: list[PromptMessage] = [
SystemPromptMessage(content=system_prompt),
*prompt_messages,
]
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
# Get model instance and schema
provider = model_config.get("provider", "")
model_name = model_config.get("name", "")
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=provider,
model=model_name,
)
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
if not model_schema:
return cls._error_response(f"Model schema not found for {model_name}")
model_parameters = model_config.get("completion_params", {})
try:
response = invoke_llm_with_pydantic_model(
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=complete_messages,
output_model=CodeNodeStructuredOutput,
model_parameters=model_parameters,
stream=False,
tenant_id=tenant_id,
)
return cls._parse_code_node_output(
response.structured_output, language, parameter_info.get("type", "string")
)
except InvokeError as e:
return cls._error_response(str(e))
except Exception as e:
logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
return cls._error_response(f"An unexpected error occurred: {str(e)}")
@classmethod
def _error_response(cls, error: str) -> dict:
"""Return error response in CodeNodeData format."""
return {
"variables": [],
"code_language": "python3",
"code": "",
"outputs": {},
"message": "",
"error": error,
}
@classmethod
def generate_suggested_questions(
cls,
tenant_id: str,
workflow_id: str,
node_id: str,
parameter_name: str,
language: str,
model_config: dict | None = None,
) -> dict:
"""
Generate suggested questions for context generation.
Returns dict with questions array and error field.
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
from services.workflow_service import WorkflowService
# Get workflow context (reuse existing logic)
with Session(db.engine) as session:
stmt = select(App).where(App.id == workflow_id)
app = session.scalar(stmt)
if not app:
return {"questions": [], "error": f"App {workflow_id} not found"}
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
if not current_node:
return {"questions": [], "error": f"Node {node_id} not found"}
parameter_info = cls._get_parameter_info(
tenant_id=tenant_id,
node_data=current_node.get("data", {}),
parameter_name=parameter_name,
)
# Build prompt
system_prompt = cls._build_suggested_questions_prompt(
upstream_nodes=upstream_nodes,
current_node=current_node,
parameter_info=parameter_info,
language=language,
)
prompt_messages: list[PromptMessage] = [
SystemPromptMessage(content=system_prompt),
]
# Get model instance - use default if model_config not provided
model_manager = ModelManager()
if model_config:
provider = model_config.get("provider", "")
model_name = model_config.get("name", "")
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=provider,
model=model_name,
)
else:
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
model_name = model_instance.model
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
if not model_schema:
return {"questions": [], "error": f"Model schema not found for {model_name}"}
completion_params = model_config.get("completion_params", {}) if model_config else {}
model_parameters = {**completion_params, "max_tokens": 256}
try:
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=model_parameters,
stream=False,
tenant_id=tenant_id,
)
questions = response.structured_output.get("questions", []) if response.structured_output else []
return {"questions": questions, "error": ""}
except InvokeError as e:
return {"questions": [], "error": str(e)}
except Exception as e:
logger.exception("Failed to generate suggested questions, model: %s", model_name)
return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def _build_suggested_questions_prompt(
cls,
upstream_nodes: list[dict],
current_node: dict,
parameter_info: dict,
language: str = "English",
) -> str:
"""Build minimal prompt for suggested questions generation."""
# Simplify upstream nodes to reduce tokens
sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
param_type = parameter_info.get("type", "string")
param_desc = parameter_info.get("description", "")[:100]
return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""
@classmethod
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
"""
Get all upstream nodes via edge backtracking.
Traverses the graph backwards from node_id to collect all reachable nodes.
"""
from collections import defaultdict
nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
edges = graph_dict.get("edges", [])
# Build reverse adjacency list
reverse_adj: dict[str, list[str]] = defaultdict(list)
for edge in edges:
reverse_adj[edge["target"]].append(edge["source"])
# BFS to find all upstream nodes
visited: set[str] = set()
queue = [node_id]
upstream: list[dict] = []
while queue:
current = queue.pop(0)
for source in reverse_adj.get(current, []):
if source not in visited:
visited.add(source)
queue.append(source)
if source in nodes:
upstream.append(cls._extract_node_info(nodes[source]))
return upstream
@classmethod
def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
"""Get node by ID from graph."""
for node in graph_dict.get("nodes", []):
if node["id"] == node_id:
return node
return None
@classmethod
def _extract_node_info(cls, node: dict) -> dict:
"""Extract minimal node info with outputs based on node type."""
node_type = node["data"]["type"]
node_data = node.get("data", {})
# Build outputs based on node type (only type, no description to reduce tokens)
outputs: dict[str, str] = {}
match node_type:
case "start":
for var in node_data.get("variables", []):
name = var.get("variable", var.get("name", ""))
outputs[name] = var.get("type", "string")
case "llm":
outputs["text"] = "string"
case "code":
for name, output in node_data.get("outputs", {}).items():
outputs[name] = output.get("type", "string")
case "http-request":
outputs = {"body": "string", "status_code": "number", "headers": "object"}
case "knowledge-retrieval":
outputs["result"] = "array[object]"
case "tool":
outputs = {"text": "string", "json": "object"}
case _:
outputs["output"] = "string"
info: dict = {
"id": node["id"],
"title": node_data.get("title", node["id"]),
"outputs": outputs,
}
# Only include description if not empty
desc = node_data.get("desc", "")
if desc:
info["desc"] = desc
return info
@classmethod
def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
"""Get parameter info from tool schema using ToolManager."""
default_info = {"name": parameter_name, "type": "string", "description": ""}
if node_data.get("type") != "tool":
return default_info
try:
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
provider_type_str = node_data.get("provider_type", "")
provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN
tool_runtime = ToolManager.get_tool_runtime(
provider_type=provider_type,
provider_id=node_data.get("provider_id", ""),
tool_name=node_data.get("tool_name", ""),
tenant_id=tenant_id,
invoke_from=InvokeFrom.DEBUGGER,
)
parameters = tool_runtime.get_merged_runtime_parameters()
for param in parameters:
if param.name == parameter_name:
return {
"name": param.name,
"type": param.type.value if hasattr(param.type, "value") else str(param.type),
"description": param.llm_description
or (param.human_description.en_US if param.human_description else ""),
"required": param.required,
}
except Exception as e:
logger.debug("Failed to get parameter info from ToolManager: %s", e)
return default_info
@classmethod
def _build_extractor_system_prompt(
cls,
upstream_nodes: list[dict],
current_node: dict,
parameter_info: dict,
language: str,
) -> str:
"""Build system prompt for extractor code generation."""
upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
param_type = parameter_info.get("type", "string")
return f"""You are a code generator for workflow automation.
Generate {language} code to extract/transform upstream node outputs for the target parameter.
## Upstream Nodes
{upstream_json}
## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
"""
@classmethod
def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
"""
Parse structured output to CodeNodeData format.
Args:
content: Structured output dict from invoke_llm_with_structured_output
language: Code language
parameter_type: Expected parameter type
Returns dict with variables, code_language, code, outputs, message, error.
"""
if content is None:
return cls._error_response("Empty or invalid response from LLM")
# Validate and normalize variables
variables = [
{"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
for v in content.get("variables", [])
if isinstance(v, dict)
]
outputs = content.get("outputs", {"result": {"type": parameter_type}})
return {
"variables": variables,
"code_language": language,
"code": content.get("code", ""),
"outputs": outputs,
"message": content.get("explanation", ""),
"error": "",
}
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
@@ -960,10 +529,6 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
model_name = model_config.get("name", "")
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
if not model_schema:
return {"error": f"Model schema not found for {model_name}"}
match node_type:
case "llm" | "agent":
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
@@ -987,18 +552,20 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
model_parameters = {"temperature": 0.4}
try:
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=list(prompt_messages),
output_model=InstructionModifyOutput,
model_parameters=model_parameters,
stream=False,
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
return response.structured_output or {}
generated_raw = response.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
from core.variables.types import SegmentType
from core.workflow.nodes.base.entities import VariableSelector
class SuggestedQuestionsOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
questions: list[str] = Field(min_length=3, max_length=3)
class CodeNodeOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
type: SegmentType
class CodeNodeStructuredOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
variables: list[VariableSelector]
code: str
outputs: dict[str, CodeNodeOutput]
explanation: str
class InstructionModifyOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
modified: str
message: str

View File

@@ -1,188 +0,0 @@
"""
File reference detection and conversion for structured output.
This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""
import uuid
from collections.abc import Mapping
from typing import Any
from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping
FILE_REF_FORMAT = "dify-file-ref"
def is_file_ref_property(schema: dict) -> bool:
"""Check if a schema property is a file reference."""
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
"""
Recursively detect file reference fields in schema.
Args:
schema: JSON Schema to analyze
path: Current path in the schema (used for recursion)
Returns:
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
"""
file_ref_paths: list[str] = []
schema_type = schema.get("type")
if schema_type == "object":
for prop_name, prop_schema in schema.get("properties", {}).items():
current_path = f"{path}.{prop_name}" if path else prop_name
if is_file_ref_property(prop_schema):
file_ref_paths.append(current_path)
elif isinstance(prop_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
elif schema_type == "array":
items_schema = schema.get("items", {})
array_path = f"{path}[*]" if path else "[*]"
if is_file_ref_property(items_schema):
file_ref_paths.append(array_path)
elif isinstance(items_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
return file_ref_paths
def convert_file_refs_in_output(
output: Mapping[str, Any],
json_schema: Mapping[str, Any],
tenant_id: str,
) -> dict[str, Any]:
"""
Convert file ID strings to File objects based on schema.
Args:
output: The structured_output from LLM result
json_schema: The original JSON schema (to detect file ref fields)
tenant_id: Tenant ID for file lookup
Returns:
Output with file references converted to File objects
"""
file_ref_paths = detect_file_ref_fields(json_schema)
if not file_ref_paths:
return dict(output)
result = _deep_copy_dict(output)
for path in file_ref_paths:
_convert_path_in_place(result, path.split("."), tenant_id)
return result
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
"""Deep copy a mapping to a mutable dict."""
result: dict[str, Any] = {}
for key, value in obj.items():
if isinstance(value, Mapping):
result[key] = _deep_copy_dict(value)
elif isinstance(value, list):
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
else:
result[key] = value
return result
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
"""Convert file refs at the given path in place, wrapping in Segment types."""
if not path_parts:
return
current = path_parts[0]
remaining = path_parts[1:]
# Handle array notation like "files[*]"
if current.endswith("[*]"):
key = current[:-3] if current != "[*]" else None
target = obj.get(key) if key else obj
if isinstance(target, list):
if remaining:
# Nested array with remaining path - recurse into each item
for item in target:
if isinstance(item, dict):
_convert_path_in_place(item, remaining, tenant_id)
else:
# Array of file IDs - convert all and wrap in ArrayFileSegment
files: list[File] = []
for item in target:
file = _convert_file_id(item, tenant_id)
if file is not None:
files.append(file)
# Replace the array with ArrayFileSegment
if key:
obj[key] = ArrayFileSegment(value=files)
return
if not remaining:
# Leaf node - convert the value and wrap in FileSegment
if current in obj:
file = _convert_file_id(obj[current], tenant_id)
if file is not None:
obj[current] = FileSegment(value=file)
else:
obj[current] = None
else:
# Recurse into nested object
if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, tenant_id)
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
"""
Convert a file ID string to a File object.
Tries multiple file sources in order:
1. ToolFile (files generated by tools/workflows)
2. UploadFile (files uploaded by users)
"""
if not isinstance(file_id, str):
return None
# Validate UUID format
try:
uuid.UUID(file_id)
except ValueError:
return None
# Try ToolFile first (files generated by tools/workflows)
try:
return build_from_mapping(
mapping={
"transfer_method": "tool_file",
"tool_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass
# Try UploadFile (files uploaded by users)
try:
return build_from_mapping(
mapping={
"transfer_method": "local_file",
"upload_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass
# File not found in any source
return None

View File

@@ -2,13 +2,12 @@ import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import Any, Literal, TypeVar, cast, overload
from typing import Any, Literal, cast, overload
import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import TypeAdapter, ValidationError
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@@ -44,9 +43,6 @@ class SpecialModelType(StrEnum):
OLLAMA = "ollama"
T = TypeVar("T", bound=BaseModel)
@overload
def invoke_llm_with_structured_output(
*,
@@ -61,7 +57,6 @@ def invoke_llm_with_structured_output(
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
def invoke_llm_with_structured_output(
@@ -77,7 +72,6 @@ def invoke_llm_with_structured_output(
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
@@ -93,7 +87,6 @@ def invoke_llm_with_structured_output(
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
*,
@@ -108,30 +101,23 @@ def invoke_llm_with_structured_output(
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
Invoke large language model with structured output.
Invoke large language model with structured output
1. This method invokes model_instance.invoke_llm with json_schema
2. Try to parse the result as structured output
This method invokes model_instance.invoke_llm with json_schema and parses
the result as structured output.
:param provider: model provider name
:param model_schema: model schema entity
:param model_instance: model instance to invoke
:param prompt_messages: prompt messages
:param json_schema: json schema for structured output
:param json_schema: json schema
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:param tenant_id: tenant ID for file reference conversion. When provided and
json_schema contains file reference fields (format: "dify-file-ref"),
file IDs in the output will be automatically converted to File objects.
:return: full response or stream response chunk generator result
"""
# handle native json schema
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
@@ -167,18 +153,8 @@ def invoke_llm_with_structured_output(
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
)
structured_output = _parse_structured_output(llm_result.message.content)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
return LLMResultWithStructuredOutput(
structured_output=structured_output,
structured_output=_parse_structured_output(llm_result.message.content),
model=llm_result.model,
message=llm_result.message,
usage=llm_result.usage,
@@ -210,18 +186,8 @@ def invoke_llm_with_structured_output(
delta=event.delta,
)
structured_output = _parse_structured_output(result_text)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
yield LLMResultChunkWithStructuredOutput(
structured_output=structured_output,
structured_output=_parse_structured_output(result_text),
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
@@ -236,87 +202,6 @@ def invoke_llm_with_structured_output(
return generator()
@overload
def invoke_llm_with_pydantic_model(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
def invoke_llm_with_pydantic_model(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with a Pydantic output model.
This helper generates a JSON schema from the Pydantic model, invokes the
structured-output LLM path, and validates the result in non-streaming mode.
"""
if stream:
raise ValueError("invoke_llm_with_pydantic_model only supports stream=False")
json_schema = _schema_from_pydantic(output_model)
result = invoke_llm_with_structured_output(
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=False,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
)
structured_output = result.structured_output
if structured_output is None:
raise OutputParserError("Structured output is empty")
validated_output = _validate_structured_output(output_model, structured_output)
return result.model_copy(update={"structured_output": validated_output})
def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
return output_model.model_json_schema()
def _validate_structured_output(
output_model: type[T],
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
try:
validated_output = output_model.model_validate(structured_output)
except ValidationError as exc:
raise OutputParserError(f"Structured output validation failed: {exc}") from exc
return validated_output.model_dump(mode="python")
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,

View File

@@ -1,45 +0,0 @@
"""Utility functions for LLM generator."""
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
"""
Deserialize list of dicts to list[PromptMessage].
Expected format:
[
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."},
]
"""
result: list[PromptMessage] = []
for msg in messages:
role = PromptMessageRole.value_of(msg["role"])
content = msg.get("content", "")
match role:
case PromptMessageRole.USER:
result.append(UserPromptMessage(content=content))
case PromptMessageRole.ASSISTANT:
result.append(AssistantPromptMessage(content=content))
case PromptMessageRole.SYSTEM:
result.append(SystemPromptMessage(content=content))
case PromptMessageRole.TOOL:
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
return result
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
"""
Serialize list[PromptMessage] to list of dicts.
"""
return [{"role": msg.role.value, "content": msg.content} for msg in messages]

View File

@@ -1,267 +0,0 @@
# Memory Module
This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
## Overview
The memory module contains two types of memory implementations:
1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ Memory Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ TokenBufferMemory │ │
│ │ Scope: Conversation │ │
│ │ Storage: Database (Message table) │ │
│ │ Key: conversation_id │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ NodeTokenBufferMemory │ │
│ │ Scope: Node within Conversation │ │
│ │ Storage: WorkflowNodeExecutionModel.outputs["context"] │ │
│ │ Key: (conversation_id, node_id, workflow_run_id) │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
```
---
## TokenBufferMemory (Existing)
### Purpose
`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
### Key Features
- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
### Data Flow
```
Message Table TokenBufferMemory LLM
│ │ │
│ SELECT * FROM messages │ │
│ WHERE conversation_id = ? │ │
│ ORDER BY created_at DESC │ │
├─────────────────────────────────▶│ │
│ │ │
│ extract_thread_messages() │
│ │ │
│ build_prompt_message_with_files() │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage]
│ ├───────────────────────▶│
│ │ │
```
### Thread Extraction
When a user regenerates a response, a new thread is created:
```
Message A (user)
└── Message A' (assistant)
└── Message B (user)
└── Message B' (assistant)
└── Message A'' (assistant, regenerated) ← New thread
└── Message C (user)
└── Message C' (assistant)
```
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
### Usage
```python
from core.memory.token_buffer_memory import TokenBufferMemory
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```
---
## NodeTokenBufferMemory
### Purpose
`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
### Use Cases
1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history
### Design: Zero Extra Storage
**Key insight**: LLM node already saves complete context in `outputs["context"]`.
Each LLM node execution outputs:
```python
outputs = {
"text": clean_text,
"context": self._build_context(prompt_messages, clean_text), # Complete dialogue history!
...
}
```
This `outputs["context"]` contains:
- All previous user/assistant messages (excluding system prompt)
- The current assistant response
**No separate storage needed** - we just read from the last execution's `outputs["context"]`.
### Benefits
| Aspect | Old Design (Object Storage) | New Design (outputs["context"]) |
|--------|----------------------------|--------------------------------|
| Storage | Separate JSON file | Already in WorkflowNodeExecutionModel |
| Concurrency | Race condition risk | No issue (each execution is INSERT) |
| Cleanup | Need separate cleanup task | Follows node execution lifecycle |
| Migration | Required | None |
| Complexity | High | Low |
### Data Flow
```
WorkflowNodeExecutionModel NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ │ │
│ SELECT outputs FROM │ │
│ workflow_node_executions │ │
│ WHERE workflow_run_id = ? │ │
│ AND node_id = ? │ │
│◀─────────────────────────────────┤ │
│ │ │
│ outputs["context"] │ │
├─────────────────────────────────▶│ │
│ │ │
│ deserialize PromptMessages │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
```
### Thread Tracking
Thread extraction still uses `Message` table's `parent_message_id` structure:
1. Query `Message` table for conversation → get thread's `workflow_run_ids`
2. Get the last completed `workflow_run_id` in the thread
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]`
### API
```python
class NodeTokenBufferMemory:
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
"""Initialize node-level memory."""
...
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
Reads from last completed execution's outputs["context"].
"""
...
# Legacy methods (no-op, kept for compatibility)
def add_messages(self, *args, **kwargs) -> None: pass
def flush(self) -> None: pass
def clear(self) -> None: pass
```
### Configuration
Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
```python
class MemoryMode(StrEnum):
CONVERSATION = "conversation" # Use TokenBufferMemory (default)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class MemoryConfig(BaseModel):
role_prefix: RolePrefix | None = None
window: MemoryWindowConfig | None = None
query_prompt_template: str | None = None
mode: MemoryMode = MemoryMode.CONVERSATION
```
**Mode Behavior:**
| Mode | Memory Class | Scope | Availability |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.
---
## Comparison
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ---------------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | WorkflowNodeExecutionModel.outputs |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via context serialization) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |
---
## Extending to Other Nodes
Currently only **LLM Node** outputs `context` in its outputs. To enable node memory for other nodes:
1. Add `outputs["context"] = self._build_context(prompt_messages, response)` in the node
2. The `NodeTokenBufferMemory` will automatically pick it up
Nodes that could potentially support this:
- `question_classifier`
- `parameter_extractor`
- `agent`
---
## Future Considerations
1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
2. **Compression**: For very long conversations, consider summarization strategies
3. **Extension**: Other nodes may benefit from node-level memory

View File

@@ -1,11 +0,0 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory
__all__ = [
"BaseMemory",
"NodeTokenBufferMemory",
"TokenBufferMemory",
]

View File

@@ -1,83 +0,0 @@
"""
Base memory interfaces and types.
This module defines the common protocol for memory implementations.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
class BaseMemory(ABC):
"""
Abstract base class for memory implementations.
Provides a common interface for both conversation-level and node-level memory.
"""
@abstractmethod
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Sequence of PromptMessage for LLM context
"""
pass
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt as formatted text.
:param human_prefix: Prefix for human messages
:param ai_prefix: Prefix for assistant messages
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Formatted history text
"""
from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@@ -1,197 +0,0 @@
"""
Node-level Token Buffer Memory for Chatflow.
This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.
Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.
Design:
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed - the context is already saved during node execution
- Thread tracking leverages Message table's parent_message_id structure
"""
import logging
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
class NodeTokenBufferMemory(BaseMemory):
"""
Node-level Token Buffer Memory.
Provides node-scoped memory within a conversation. Each LLM node can maintain
its own independent conversation history.
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
which is already saved during node execution. No separate storage needed.
"""
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
self.app_id = app_id
self.conversation_id = conversation_id
self.node_id = node_id
self.tenant_id = tenant_id
self.model_instance = model_instance
def _get_thread_workflow_run_ids(self) -> list[str]:
"""
Get workflow_run_ids for the current thread by querying Message table.
Returns workflow_run_ids in chronological order (oldest first).
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(Message)
.where(Message.conversation_id == self.conversation_id)
.order_by(Message.created_at.desc())
.limit(500)
)
messages = list(session.scalars(stmt).all())
if not messages:
return []
# Extract thread messages using existing logic
thread_messages = extract_thread_messages(messages)
# For newly created message, its answer is temporarily empty, skip it
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
# Reverse to get chronological order, extract workflow_run_ids
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
"""Deserialize a dict to PromptMessage based on role."""
role = msg_dict.get("role")
if role in (PromptMessageRole.USER, "user"):
return UserPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
return AssistantPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.SYSTEM, "system"):
return SystemPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.TOOL, "tool"):
return ToolPromptMessage.model_validate(msg_dict)
else:
return PromptMessage.model_validate(msg_dict)
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
"""Deserialize context data from outputs to list of PromptMessage."""
messages = []
for msg_dict in context_data:
try:
msg = self._deserialize_prompt_message(msg_dict)
msg = self._restore_multimodal_content(msg)
messages.append(msg)
except Exception as e:
logger.warning("Failed to deserialize prompt message: %s", e)
return messages
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
"""
Restore multimodal content (base64 or url) from file_ref.
When context is saved, base64_data is cleared to save storage space.
This method restores the content by parsing file_ref (format: "method:id_or_url").
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, restoring multimodal data from file references
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# restore_multimodal_content preserves the concrete subclass type
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
History is read directly from the last completed node execution's outputs["context"].
"""
_ = message_limit # unused, kept for interface compatibility
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
if not thread_workflow_run_ids:
return []
# Get the last completed workflow_run_id (contains accumulated context)
last_run_id = thread_workflow_run_ids[-1]
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
WorkflowNodeExecutionModel.node_id == self.node_id,
WorkflowNodeExecutionModel.status == "succeeded",
)
execution = session.scalars(stmt).first()
if not execution:
return []
outputs = execution.outputs_dict
if not outputs:
return []
context_data = outputs.get("context")
if not context_data or not isinstance(context_data, list):
return []
prompt_messages = self._deserialize_context(context_data)
if not prompt_messages:
return []
# Truncate by token limit
try:
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
while current_tokens > max_token_limit and len(prompt_messages) > 1:
prompt_messages.pop(0)
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
except Exception as e:
logger.warning("Failed to count tokens for truncation: %s", e)
return prompt_messages

View File

@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class TokenBufferMemory(BaseMemory):
class TokenBufferMemory:
def __init__(
self,
conversation: Conversation,
@@ -115,14 +115,10 @@ class TokenBufferMemory(BaseMemory):
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
"""
@@ -204,3 +200,44 @@ class TokenBufferMemory(BaseMemory):
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
:param ai_prefix: ai prefix
:param max_token_limit: max token limit
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@@ -91,9 +91,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
# File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
@@ -279,5 +276,7 @@ class ToolPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
# ToolPromptMessage is not empty if it has content OR has a tool_call_id
return super().is_empty() and not self.tool_call_id
if not super().is_empty() and not self.tool_call_id:
return False
return True

View File

@@ -320,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
raise InvokeRateLimitError(description=args.get("description"))
raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
raise InvokeAuthorizationError(description=args.get("description"))
raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
raise InvokeBadRequestError(description=args.get("description"))
raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
raise InvokeConnectionError(description=args.get("description"))
raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
raise InvokeServerUnavailableError(description=args.get("description"))
raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -339,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
raise TriggerPluginInvokeError(description=error_object.get("description"))
raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
raise EventIgnoreError(description=error_object.get("description"))
raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:

View File

@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):
def _set_histories_variable(
self,
memory: BaseMemory,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,

View File

@@ -1,4 +1,3 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
@@ -6,13 +5,6 @@ from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageRole
class MemoryMode(StrEnum):
"""Memory mode for LLM nodes."""
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class ChatModelMessage(BaseModel):
"""
Chat Message.
@@ -56,4 +48,3 @@ class MemoryConfig(BaseModel):
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
mode: MemoryMode = MemoryMode.CONVERSATION

View File

@@ -1,7 +1,7 @@
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _append_chat_histories(
self,
memory: BaseMemory,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
@@ -52,7 +52,7 @@ class PromptTransform:
def _get_history_messages_from_memory(
self,
memory: BaseMemory,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: str | None = None,
@@ -73,7 +73,7 @@ class PromptTransform:
return memory.get_history_prompt_text(**kwargs)
def _get_history_messages_list_from_memory(
self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
return list(

View File

@@ -1047,8 +1047,6 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
if not isinstance(tool_input.value, list):
raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
@@ -1058,11 +1056,6 @@ class ToolManager:
elif tool_input.type == "mixed":
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text
elif tool_input.type == "mention":
# Mention type not supported in agent mode
raise ToolParameterError(
f"Mention type not supported in agent for parameter '{parameter.name}'"
)
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
runtime_parameters[parameter.name] = parameter_value

View File

@@ -4,7 +4,6 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
@@ -21,7 +20,6 @@ from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
ArrayVariable,
FileVariable,
@@ -44,8 +42,6 @@ __all__ = [
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArrayPromptMessageSegment",
"ArrayPromptMessageVariable",
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",

View File

@@ -6,7 +6,6 @@ from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
from core.model_runtime.entities import PromptMessage
from .types import SegmentType
@@ -209,15 +208,6 @@ class ArrayBooleanSegment(ArraySegment):
value: Sequence[bool]
class ArrayPromptMessageSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
value: Sequence[PromptMessage]
def to_object(self):
"""Convert to JSON-serializable format for database storage and frontend."""
return [msg.model_dump() for msg in self.value]
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
@@ -258,7 +248,6 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
),
Discriminator(get_segment_discriminator),
]

View File

@@ -45,7 +45,6 @@ class SegmentType(StrEnum):
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_PROMPT_MESSAGE = "array[message]"
NONE = "none"

View File

@@ -3,10 +3,8 @@ from typing import Any
import orjson
from core.model_runtime.entities import PromptMessage
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from .segments import ArrayFileSegment, FileSegment, Segment
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
@@ -18,7 +16,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
def segment_orjson_default(o: Any):
"""Default function for orjson serialization of Segment types"""
if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
@@ -26,8 +24,6 @@ def segment_orjson_default(o: Any):
return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
elif isinstance(o, PromptMessage):
return o.model_dump()
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

View File

@@ -12,7 +12,6 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
@@ -111,10 +110,6 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
@@ -165,7 +160,6 @@ Variable: TypeAlias = Annotated[
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),

File diff suppressed because it is too large Load Diff

View File

@@ -63,7 +63,6 @@ class NodeType(StrEnum):
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
GROUP = "group"
@property
def is_trigger_node(self) -> bool:
@@ -253,7 +252,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
MENTION_PARENT_ID = "mention_parent_id" # parent node id for extractor nodes
class WorkflowNodeExecutionStatus(StrEnum):

View File

@@ -307,14 +307,7 @@ class Graph:
if not node_configs:
raise ValueError("Graph must have at least one node")
# Filter out UI-only node types:
# - custom-note: top-level type (node_config.type == "custom-note")
# - group: data-level type (node_config.data.type == "group")
node_configs = [
node_config
for node_config in node_configs
if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
]
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)

View File

@@ -93,8 +93,8 @@ class EventHandler:
Args:
event: The event to handle
"""
# Events in loops, iterations, or extractor groups are always collected
if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return
return self._dispatch(event)
@@ -125,11 +125,6 @@ class EventHandler:
Args:
event: The node started event
"""
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_started(event)
return
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
@@ -169,11 +164,6 @@ class EventHandler:
Args:
event: The node succeeded event
"""
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_success(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
@@ -236,11 +226,6 @@ class EventHandler:
Args:
event: The node failed event
"""
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_failed(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
@@ -360,57 +345,3 @@ class EventHandler:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
def _is_extractor_node(self, node_id: str) -> bool:
"""
Check if node_id represents an extractor node (has parent_node_id).
Extractor nodes extract values from list[PromptMessage] for their parent node.
They have a parent_node_id field pointing to their parent node.
"""
node = self._graph.nodes.get(node_id)
if node is None:
return False
return node.node_data.is_extractor_node
def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
"""
Handle extractor node started event.
Extractor nodes don't need full execution tracking, just collect the event.
"""
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
"""
Handle extractor node success event.
Extractor nodes need special handling:
- Store outputs in variable pool (for reference by other nodes)
- Accumulate token usage
- Collect the event for logging
- Do NOT process edges or enqueue next nodes (parent node handles that)
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
Handle extractor node failed event.
Extractor node failures are collected for logging,
but the parent node is responsible for handling the error.
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Collect the event for logging
self._event_collector.collect(event)

View File

@@ -68,7 +68,6 @@ class _NodeRuntimeSnapshot:
predecessor_node_id: str | None
iteration_id: str | None
loop_id: str | None
mention_parent_id: str | None
created_at: datetime
@@ -231,7 +230,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
}
domain_execution = WorkflowNodeExecution(
@@ -258,7 +256,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
predecessor_node_id=event.predecessor_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
created_at=event.start_at,
)
self._node_snapshots[event.id] = snapshot

View File

@@ -21,12 +21,6 @@ class GraphNodeEventBase(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""Parent node id if this is an extractor node event.
When set, indicates this event belongs to an extractor node that
is extracting values for the specified parent node.
"""
# The version of the node, or "1" if not specified.
node_version: str = "1"

View File

@@ -12,20 +12,11 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@@ -145,9 +136,6 @@ class AgentNode(Node[AgentNodeData]):
)
return
# Fetch memory for node memory saving
memory = self._fetch_memory_for_save()
try:
yield from self._transform_message(
messages=message_stream,
@@ -161,7 +149,6 @@ class AgentNode(Node[AgentNodeData]):
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
memory=memory,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
@@ -408,20 +395,8 @@ class AgentNode(Node[AgentNodeData]):
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
"""
node_data = self.node_data
memory_config = node_data.memory
if not memory_config:
return None
# get conversation id (required for both modes in Chatflow)
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
@@ -429,26 +404,16 @@ class AgentNode(Node[AgentNodeData]):
return None
conversation_id = conversation_id_variable.value
# Return appropriate memory type based on mode
if memory_config.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == self.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
@@ -492,136 +457,6 @@ class AgentNode(Node[AgentNodeData]):
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _fetch_memory_for_save(self) -> BaseMemory | None:
"""
Fetch memory instance for saving node memory.
This is a simplified version that doesn't require model_instance.
"""
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
node_data = self.node_data
if not node_data.memory:
return None
# Get conversation_id
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_var, StringSegment):
return None
conversation_id = conversation_id_var.value
# Return appropriate memory type based on mode
if node_data.memory.mode == MemoryMode.NODE:
# For node memory, we need a model_instance for token counting
# Use a simple default model for this purpose
try:
model_instance = ModelManager().get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
)
except Exception:
return None
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory doesn't need saving here
return None
def _build_context(
self,
parameters_for_log: dict[str, Any],
user_query: str,
assistant_response: str,
agent_logs: list[AgentLogEvent],
) -> list[PromptMessage]:
"""
Build context from user query, tool calls, and assistant response.
Format: user -> assistant(with tool_calls) -> tool -> assistant
The context includes:
- Current user query (always present, may be empty)
- Assistant message with tool_calls (if tools were called)
- Tool results
- Assistant's final response
"""
context_messages: list[PromptMessage] = []
# Always add user query (even if empty, to maintain conversation structure)
context_messages.append(UserPromptMessage(content=user_query or ""))
# Extract actual tool calls from agent logs
# Only include logs with label starting with "CALL " - these are real tool invocations
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
for log in agent_logs:
if log.status == "success" and log.label and log.label.startswith("CALL "):
# Extract tool name from label (format: "CALL tool_name")
tool_name = log.label[5:] # Remove "CALL " prefix
tool_call_id = log.message_id
# Parse tool response from data
data = log.data or {}
tool_response = ""
# Try to extract the actual tool response
if "tool_response" in data:
tool_response = data["tool_response"]
elif "output" in data:
tool_response = data["output"]
elif "result" in data:
tool_response = data["result"]
if isinstance(tool_response, dict):
tool_response = str(tool_response)
# Get tool input for arguments
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
if isinstance(tool_input, dict):
import json
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
else:
tool_input_str = str(tool_input) if tool_input else ""
if tool_response:
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_name,
arguments=tool_input_str,
),
)
)
tool_results.append((tool_call_id, tool_name, str(tool_response)))
# Add assistant message with tool_calls if there were tool calls
if tool_calls:
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
# Add tool result messages
for tool_call_id, tool_name, result in tool_results:
context_messages.append(
ToolPromptMessage(
content=result,
tool_call_id=tool_call_id,
name=tool_name,
)
)
# Add final assistant response
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
@@ -632,7 +467,6 @@ class AgentNode(Node[AgentNodeData]):
node_type: NodeType,
node_id: str,
node_execution_id: str,
memory: BaseMemory | None = None,
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -877,12 +711,6 @@ class AgentNode(Node[AgentNodeData]):
is_final=True,
)
# Get user query from parameters for building context
user_query = parameters_for_log.get("query", "")
# Build context from history, user query, tool calls and assistant response
context = self._build_context(parameters_for_log, user_query, text, agent_logs)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -891,7 +719,6 @@ class AgentNode(Node[AgentNodeData]):
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
"context": context,
**variables,
},
metadata={

View File

@@ -1,10 +1,4 @@
from .entities import (
BaseIterationNodeData,
BaseIterationState,
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
)
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [

View File

@@ -175,16 +175,6 @@ class BaseNodeData(ABC, BaseModel):
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Parent node ID when this node is used as an extractor.
# If set, this node is an "attached" extractor node that extracts values
# from list[PromptMessage] for the parent node's parameters.
parent_node_id: str | None = None
@property
def is_extractor_node(self) -> bool:
"""Check if this node is an extractor node (has parent_node_id)."""
return self.parent_node_id is not None
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:

View File

@@ -270,87 +270,10 @@ class Node(Generic[NodeDataT]):
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
"""
Find all extractor node configurations that have parent_node_id == self._node_id.
Returns:
List of node configuration dicts for extractor nodes
"""
nodes = self.graph_config.get("nodes", [])
extractor_configs = []
for node_config in nodes:
node_data = node_config.get("data", {})
if node_data.get("parent_node_id") == self._node_id:
extractor_configs.append(node_config)
return extractor_configs
def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
"""
Execute all extractor nodes associated with this node.
Extractor nodes are nodes with parent_node_id == self._node_id.
They are executed before the main node to extract values from list[PromptMessage].
"""
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
extractor_configs = self._find_extractor_node_configs()
logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
if not extractor_configs:
return
for config in extractor_configs:
node_id = config.get("id")
node_data = config.get("data", {})
node_type_str = node_data.get("type")
if not node_id or not node_type_str:
continue
# Get node class
try:
node_type = NodeType(node_type_str)
except ValueError:
continue
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
continue
node_version = str(node_data.get("version", "1"))
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
if not node_cls:
continue
# Instantiate and execute the extractor node
extractor_node = node_cls(
id=node_id,
config=config,
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
# Execute and process extractor node events
for event in extractor_node.run():
# Tag event with parent node id for stream ordering and history tracking
if isinstance(event, GraphNodeEventBase):
event.in_mention_parent_id = self._node_id
if isinstance(event, NodeRunSucceededEvent):
# Store extractor node outputs in variable pool
outputs: Mapping[str, Any] = event.node_run_result.outputs
for variable_name, variable_value in outputs.items():
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
if not isinstance(event, NodeRunStreamChunkEvent):
yield event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Step 1: Execute associated extractor nodes before main node execution
yield from self._execute_mention_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=execution_id,

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,28 +58,9 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class PromptMessageContext(BaseModel):
"""Context variable reference in prompt template.
YAML/JSON format: { "$context": ["node_id", "variable_name"] }
This will be expanded to list[PromptMessage] at runtime.
"""
model_config = ConfigDict(populate_by_name=True)
value_selector: Sequence[str] = Field(alias="$context")
# Union type for prompt template items (static message or context variable reference)
PromptTemplateItem: TypeAlias = Annotated[
LLMNodeChatModelMessage | PromptMessageContext,
Field(discriminator=None),
]
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: MemoryConfig | None = None
context: ContextConfig

View File

@@ -8,20 +8,12 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
@@ -94,56 +86,25 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool,
app_id: str,
tenant_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
node_id: str = "",
) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
:param variable_pool: Variable pool containing system variables
:param app_id: Application ID
:param tenant_id: Tenant ID
:param node_data_memory: Memory configuration
:param model_instance: Model instance for token counting
:param node_id: Node ID in the workflow (required for node mode)
:return: Memory instance or None if not applicable
"""
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
# Get conversation_id from variable pool (required for both modes in Chatflow)
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# Return appropriate memory type based on mode
if node_data_memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
if not node_id:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return NodeTokenBufferMemory(
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
tenant_id=tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
@@ -209,87 +170,3 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
)
session.execute(stmt)
session.commit()
def build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
) -> list[PromptMessage]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
"""
context_messages: list[PromptMessage] = [
_truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
"""
Truncate multi-modal content base64 data in a message to avoid storing large data.
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
If file_ref is present, clears base64_data and url (they can be restored later).
Otherwise, truncates base64_data as fallback for legacy data.
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, handling multi-modal data based on file_ref availability
new_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
if item.file_ref:
# Clear base64 and url, keep file_ref for later restoration
new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
else:
# Fallback: truncate base64_data if no file_ref (legacy data)
truncated_base64 = ""
if item.base64_data:
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
else:
new_content.append(item)
return message.model_copy(update={"content": new_content})
def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
"""
Restore multimodal content (base64 or url) in a list of PromptMessages.
When context is saved, base64_data is cleared to save storage space.
This function restores the content by parsing file_ref in each MultiModalPromptMessageContent.
Args:
messages: List of PromptMessages that may contain truncated multimodal content
Returns:
List of PromptMessages with restored multimodal content
"""
from core.file import file_manager
return [_restore_message_content(msg, file_manager) for msg in messages]
def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage:
"""Restore multimodal content in a single PromptMessage."""
content = message.content
if content is None or isinstance(content, str):
return message
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})

View File

@@ -7,7 +7,7 @@ import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
@@ -16,7 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
ImagePromptMessageContent,
@@ -51,7 +51,6 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
@@ -88,7 +87,6 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
PromptMessageContext,
)
from .exc import (
InvalidContextStructureError,
@@ -161,9 +159,8 @@ class LLMNode(Node[LLMNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
try:
# Parse prompt template to separate static messages and context references
prompt_template = self.node_data.prompt_template
static_messages, context_refs, template_order = self._parse_prompt_template()
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
@@ -211,10 +208,8 @@ class LLMNode(Node[LLMNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
query: str | None = None
@@ -225,40 +220,21 @@ class LLMNode(Node[LLMNodeData]):
):
query = query_variable.text
# Get prompt messages
prompt_messages: Sequence[PromptMessage]
stop: Sequence[str] | None
if isinstance(prompt_template, list) and context_refs:
prompt_messages, stop = self._build_prompt_messages_with_context(
context_refs=context_refs,
template_order=template_order,
static_messages=static_messages,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config,
context_files=context_files,
)
else:
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
self.node_data.prompt_template,
),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
generator = LLMNode.invoke_llm(
@@ -274,7 +250,6 @@ class LLMNode(Node[LLMNodeData]):
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self.node_data.reasoning_format,
tenant_id=self.tenant_id,
)
structured_output: LLMStructuredOutput | None = None
@@ -326,7 +301,6 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
@@ -393,7 +367,6 @@ class LLMNode(Node[LLMNodeData]):
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
tenant_id: str | None = None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@@ -417,7 +390,6 @@ class LLMNode(Node[LLMNodeData]):
stop=list(stop or []),
stream=True,
user=user_id,
tenant_id=tenant_id,
)
else:
request_start_time = time.perf_counter()
@@ -609,212 +581,6 @@ class LLMNode(Node[LLMNodeData]):
return messages
def _parse_prompt_template(
self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
"""
Parse prompt_template to separate static messages and context references.
Returns:
Tuple of (static_messages, context_refs, template_order)
- static_messages: list of LLMNodeChatModelMessage
- context_refs: list of PromptMessageContext
- template_order: list of (index, type) tuples preserving original order
"""
prompt_template = self.node_data.prompt_template
static_messages: list[LLMNodeChatModelMessage] = []
context_refs: list[PromptMessageContext] = []
template_order: list[tuple[int, str]] = []
if isinstance(prompt_template, list):
for idx, item in enumerate(prompt_template):
if isinstance(item, PromptMessageContext):
context_refs.append(item)
template_order.append((idx, "context"))
else:
static_messages.append(item)
template_order.append((idx, "static"))
# Transform static messages for jinja2
if static_messages:
self.node_data.prompt_template = self._transform_chat_messages(static_messages)
return static_messages, context_refs, template_order
def _build_prompt_messages_with_context(
self,
*,
context_refs: list[PromptMessageContext],
template_order: list[tuple[int, str]],
static_messages: list[LLMNodeChatModelMessage],
query: str | None,
files: Sequence[File],
context: str | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
"""
Build prompt messages by combining static messages and context references in DSL order.
Returns:
Tuple of (prompt_messages, stop_sequences)
"""
variable_pool = self.graph_runtime_state.variable_pool
# Process messages in DSL order: iterate once and handle each type directly
combined_messages: list[PromptMessage] = []
context_idx = 0
static_idx = 0
for _, type_ in template_order:
if type_ == "context":
# Handle context reference
ctx_ref = context_refs[context_idx]
ctx_var = variable_pool.get(ctx_ref.value_selector)
if ctx_var is None:
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
if not isinstance(ctx_var, ArrayPromptMessageSegment):
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
# Restore multimodal content (base64/url) that was truncated when saving context
restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
combined_messages.extend(restored_messages)
context_idx += 1
else:
# Handle static message
static_msg = static_messages[static_idx]
processed_msgs = LLMNode.handle_list_messages(
messages=[static_msg],
context=context,
jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
variable_pool=variable_pool,
vision_detail_config=self.node_data.vision.configs.detail,
)
combined_messages.extend(processed_msgs)
static_idx += 1
# Append memory messages
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=self.node_data.memory,
model_config=model_config,
)
combined_messages.extend(memory_messages)
# Append current query if provided
if query:
query_message = LLMNodeChatModelMessage(
text=query,
role=PromptMessageRole.USER,
edition_type="basic",
)
query_msgs = LLMNode.handle_list_messages(
messages=[query_message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=self.node_data.vision.configs.detail,
)
combined_messages.extend(query_msgs)
# Handle files (sys_files and context_files)
combined_messages = self._append_files_to_messages(
messages=combined_messages,
sys_files=files,
context_files=context_files,
model_config=model_config,
)
# Filter empty messages and get stop sequences
combined_messages = self._filter_messages(combined_messages, model_config)
stop = self._get_stop_sequences(model_config)
return combined_messages, stop
def _append_files_to_messages(
self,
*,
messages: list[PromptMessage],
sys_files: Sequence[File],
context_files: list[File],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
"""Append sys_files and context_files to messages."""
vision_enabled = self.node_data.vision.enabled
vision_detail = self.node_data.vision.configs.detail
# Handle sys_files (will be deprecated later)
if vision_enabled and sys_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
# Handle context_files
if vision_enabled and context_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
for file in context_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
return messages
def _filter_messages(
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> list[PromptMessage]:
"""Filter empty messages and unsupported content types."""
filtered_messages: list[PromptMessage] = []
for message in messages:
if isinstance(message.content, list):
filtered_content: list[PromptMessageContentUnionTypes] = []
for content_item in message.content:
# Skip non-text content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
filtered_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
feature_map = {
PromptMessageContentType.IMAGE: ModelFeature.VISION,
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
}
required_feature = feature_map.get(content_item.type)
if required_feature and required_feature not in model_config.model_schema.features:
continue
filtered_content.append(content_item)
# Simplify single text content
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
message.content = filtered_content[0].data
else:
message.content = filtered_content
if not message.is_empty():
filtered_messages.append(message)
if not filtered_messages:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_messages
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
"""Get stop sequences from model config."""
return model_config.stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables: dict[str, Any] = {}
@@ -1012,7 +778,7 @@ class LLMNode(Node[LLMNodeData]):
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: BaseMemory | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
@@ -1571,7 +1337,7 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1588,7 +1354,7 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:

View File

@@ -7,7 +7,7 @@ from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -145,10 +145,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
if (
@@ -246,10 +244,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# transform result into standard format
result = self._transform_result(data=node_data, result=result or {})
# Build context from prompt messages and response
assistant_response = json.dumps(result, ensure_ascii=False)
context = llm_utils.build_context(prompt_messages, assistant_response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -258,7 +252,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"__is_success": 1 if not error else 0,
"__reason": error,
"__usage": jsonable_encoder(usage),
"context": context,
**result,
},
metadata={
@@ -306,7 +299,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -388,7 +381,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -426,7 +419,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -460,7 +453,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -688,7 +681,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@@ -715,7 +708,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -96,10 +96,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
# fetch instruction
node_data.instruction = node_data.instruction or ""
@@ -199,15 +197,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# Build context from prompt messages and response
assistant_response = f"class_name: {category_name}, class_id: {category_id}"
context = llm_utils.build_context(prompt_messages, assistant_response)
outputs = {
"class_name": category_name,
"class_id": category_id,
"usage": jsonable_encoder(usage),
"context": context,
}
return NodeRunResult(
@@ -319,7 +312,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@@ -1,63 +1,11 @@
import re
from collections.abc import Sequence
from typing import Any, Literal, Self, Union
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
# Pattern to match mention value format: {{@node.context@}}instruction
# The placeholder {{@node.context@}} must appear at the beginning
# Format: {{@agent_node_id.context@}} where agent_node_id is dynamic, context is fixed
MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)
def parse_mention_value(value: str) -> tuple[str, str]:
"""Parse mention value into (node_id, instruction).
Args:
value: The mention value string like "{{@llm.context@}}extract keywords"
Returns:
Tuple of (node_id, instruction)
Raises:
ValueError: If value format is invalid
"""
match = MENTION_VALUE_PATTERN.match(value)
if not match:
raise ValueError(
"For mention type, value must start with {{@node.context@}} placeholder, "
"e.g., '{{@llm.context@}}extract keywords'"
)
return match.group(1), match.group(2)
class MentionConfig(BaseModel):
"""Configuration for extracting value from context variable.
Used when a tool parameter needs to be extracted from list[PromptMessage]
context using an extractor LLM node.
Note: instruction is embedded in the value field as "{{@node.context@}}instruction"
"""
# ID of the extractor LLM node
extractor_node_id: str
# Output variable selector from extractor node
# e.g., ["text"], ["structured_output", "query"]
output_selector: Sequence[str]
# Strategy when output is None
null_strategy: Literal["raise_error", "use_default"] = "raise_error"
# Default value when null_strategy is "use_default"
# Type should match the parameter's expected type
default_value: Any = None
class ToolEntity(BaseModel):
provider_id: str
@@ -87,9 +35,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant", "mention"]
# Required config for mention type, extracting value from context variable
mention_config: MentionConfig | None = None
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
@@ -102,9 +48,6 @@ class ToolNodeData(BaseNodeData, ToolEntity):
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "mention":
# Skip here, will be validated in model_validator
pass
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
@@ -115,26 +58,6 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a string, int, float, bool or dict")
return typ
@model_validator(mode="after")
def check_mention_type(self) -> Self:
"""Validate mention type with mention_config."""
if self.type != "mention":
return self
value = self.value
if value is None:
return self
if not isinstance(value, str):
raise ValueError("value must be a string for mention type")
# For mention type, value must match format: {{@node.context@}}instruction
# This will raise ValueError if format is invalid
parse_mention_value(value)
# mention_config is required for mention type
if self.mention_config is None:
raise ValueError("mention_config is required for mention type")
return self
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version

View File

@@ -1,10 +1,7 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
logger = logging.getLogger(__name__)
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -187,7 +184,6 @@ class ToolNode(Node[ToolNodeData]):
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
for_log (bool): Whether to generate parameters for logging.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
@@ -203,37 +199,14 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
if not isinstance(tool_input.value, list):
raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
selector = tool_input.value
variable = variable_pool.get(selector)
variable = variable_pool.get(tool_input.value)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {selector} does not exist")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
continue
parameter_value = variable.value
elif tool_input.type == "mention":
# Mention type: get value from extractor node's output
if tool_input.mention_config is None:
raise ToolParameterError(
f"mention_config is required for mention type parameter '{parameter_name}'"
)
mention_config = tool_input.mention_config.model_dump()
try:
parameter_value, found = variable_pool.resolve_mention(
mention_config, parameter_name=parameter_name
)
if not found and parameter.required:
raise ToolParameterError(
f"Extractor output not found for required parameter '{parameter_name}'"
)
if not found:
continue
except ValueError as e:
raise ToolParameterError(str(e)) from e
elif tool_input.type in {"mixed", "constant"}:
template = str(tool_input.value)
segment_group = variable_pool.convert_template(template)
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
@@ -515,12 +488,8 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
if isinstance(input.value, list):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "mention":
# Mention type: value is handled by extractor node, no direct variable reference
pass
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "constant":
pass

View File

@@ -268,58 +268,6 @@ class VariablePool(BaseModel):
continue
self.add(selector, value)
def resolve_mention(
self,
mention_config: Mapping[str, Any],
/,
*,
parameter_name: str = "",
) -> tuple[Any, bool]:
"""
Resolve a mention parameter value from an extractor node's output.
Mention parameters reference values extracted by an extractor LLM node
from list[PromptMessage] context.
Args:
mention_config: A dict containing:
- extractor_node_id: ID of the extractor LLM node
- output_selector: Selector path for the output variable (e.g., ["text"])
- null_strategy: "raise_error" or "use_default"
- default_value: Value to use when null_strategy is "use_default"
parameter_name: Name of the parameter being resolved (for error messages)
Returns:
Tuple of (resolved_value, found):
- resolved_value: The extracted value, or default_value if not found
- found: True if value was found, False if using default
Raises:
ValueError: If extractor_node_id is missing, or if null_strategy is
"raise_error" and the value is not found
"""
extractor_node_id = mention_config.get("extractor_node_id")
if not extractor_node_id:
raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")
output_selector = list(mention_config.get("output_selector", []))
null_strategy = mention_config.get("null_strategy", "raise_error")
default_value = mention_config.get("default_value")
# Build full selector: [extractor_node_id, ...output_selector]
full_selector = [extractor_node_id] + output_selector
variable = self.get(full_selector)
if variable is None:
if null_strategy == "use_default":
return default_value, False
raise ValueError(
f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
f"not found for parameter '{parameter_name}'"
)
return variable.value, True
@classmethod
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""

View File

@@ -119,14 +119,16 @@ elif [[ "${MODE}" == "job" ]]; then
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
export PORT=${DIFY_PORT:-5001}
exec python -m app
else
exec gunicorn \
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
app:socketio_app
fi
fi

View File

@@ -0,0 +1,5 @@
import socketio # type: ignore[reportMissingTypeStubs]
from configs import dify_config
sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)

View File

@@ -4,7 +4,6 @@ from uuid import uuid4
from configs import dify_config
from core.file import File
from core.model_runtime.entities import PromptMessage
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
@@ -12,7 +11,6 @@ from core.variables.segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
@@ -31,7 +29,6 @@ from core.variables.variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
BooleanVariable,
FileVariable,
@@ -64,7 +61,6 @@ SEGMENT_TO_VARIABLE_MAP = {
ArrayFileSegment: ArrayFileVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayPromptMessageSegment: ArrayPromptMessageVariable,
ArrayStringSegment: ArrayStringVariable,
BooleanSegment: BooleanVariable,
FileSegment: FileVariable,
@@ -160,13 +156,7 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, File):
return FileSegment(value=value)
if isinstance(value, PromptMessage):
# Single PromptMessage should be wrapped in a list
return ArrayPromptMessageSegment(value=[value])
if isinstance(value, list):
# Check if all items are PromptMessage
if value and all(isinstance(item, PromptMessage) for item in value):
return ArrayPromptMessageSegment(value=value)
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
if all(isinstance(item, ArraySegment) for item in items):
@@ -210,7 +200,6 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
}
@@ -285,10 +274,6 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
):
segment_class = _segment_factory[inferred_type]
return segment_class(value_type=inferred_type, value=value)
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value)
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")

View File

@@ -0,0 +1,17 @@
from flask_restx import fields
online_user_partial_fields = {
"user_id": fields.String,
"username": fields.String,
"avatar": fields.String,
"sid": fields.String,
}
workflow_online_users_fields = {
"workflow_id": fields.String,
"users": fields.List(fields.Nested(online_user_partial_fields)),
}
online_user_list_fields = {
"data": fields.List(fields.Nested(workflow_online_users_fields)),
}

View File

@@ -0,0 +1,96 @@
from flask_restx import fields
from libs.helper import AvatarUrlField, TimestampField
# basic account fields for comments
account_fields = {
"id": fields.String,
"name": fields.String,
"email": fields.String,
"avatar_url": AvatarUrlField,
}
# Comment mention fields
workflow_comment_mention_fields = {
"mentioned_user_id": fields.String,
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
"reply_id": fields.String,
}
# Comment reply fields
workflow_comment_reply_fields = {
"id": fields.String,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
}
# Basic comment fields (for list views)
workflow_comment_basic_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"reply_count": fields.Integer,
"mention_count": fields.Integer,
"participants": fields.List(fields.Nested(account_fields)),
}
# Detailed comment fields (for single comment view)
workflow_comment_detail_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}
# Comment creation response fields (simplified)
workflow_comment_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Comment update response fields (simplified)
workflow_comment_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}
# Comment resolve response fields
workflow_comment_resolve_fields = {
"id": fields.String,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
}
# Reply creation response fields (simplified)
workflow_comment_reply_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Reply update response fields
workflow_comment_reply_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}

View File

@@ -0,0 +1,90 @@
"""Add workflow comments table
Revision ID: 227822d22895
Revises: 288345cd01d1
Create Date: 2025-08-22 17:26:15.255980
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '227822d22895'
down_revision = '288345cd01d1'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_comments',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('position_x', sa.Float(), nullable=False),
sa.Column('position_y', sa.Float(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('resolved_at', sa.DateTime(), nullable=True),
sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
)
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_replies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
)
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_mentions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('reply_id', models.types.StringUUID(), nullable=True),
sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
)
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.drop_index('comment_mentions_user_idx')
batch_op.drop_index('comment_mentions_reply_idx')
batch_op.drop_index('comment_mentions_comment_idx')
op.drop_table('workflow_comment_mentions')
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.drop_index('comment_replies_created_at_idx')
batch_op.drop_index('comment_replies_comment_idx')
op.drop_table('workflow_comment_replies')
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.drop_index('workflow_comments_created_at_idx')
batch_op.drop_index('workflow_comments_app_idx')
op.drop_table('workflow_comments')
# ### end Alembic commands ###

View File

@@ -9,6 +9,11 @@ from .account import (
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .comment import (
WorkflowComment,
WorkflowCommentMention,
WorkflowCommentReply,
)
from .dataset import (
AppDatasetJoin,
Dataset,
@@ -197,6 +202,9 @@ __all__ = [
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowComment",
"WorkflowCommentMention",
"WorkflowCommentReply",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",

210
api/models/comment.py Normal file
View File

@@ -0,0 +1,210 @@
"""Workflow comment models."""
from datetime import datetime
from typing import Optional
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
class WorkflowComment(Base):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
since an app has only one draft workflow at a time and comments should persist
across workflow version changes.
Attributes:
id: Comment ID
tenant_id: Workspace ID
app_id: App ID (primary association, comments belong to apps)
position_x: X coordinate on canvas
position_y: Y coordinate on canvas
content: Comment content
created_by: Creator account ID
created_at: Creation time
updated_at: Last update time
resolved: Whether comment is resolved
resolved_at: Resolution time
resolved_by: Resolver account ID
"""
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
)
@property
def created_by_account(self):
"""Get creator account."""
if hasattr(self, "_created_by_account_cache"):
return self._created_by_account_cache
return db.session.get(Account, self.created_by)
def cache_created_by_account(self, account: Account | None) -> None:
"""Cache creator account to avoid extra queries."""
self._created_by_account_cache = account
@property
def resolved_by_account(self):
"""Get resolver account."""
if hasattr(self, "_resolved_by_account_cache"):
return self._resolved_by_account_cache
if self.resolved_by:
return db.session.get(Account, self.resolved_by)
return None
def cache_resolved_by_account(self, account: Account | None) -> None:
"""Cache resolver account to avoid extra queries."""
self._resolved_by_account_cache = account
@property
def reply_count(self):
"""Get reply count."""
return len(self.replies)
@property
def mention_count(self):
"""Get mention count."""
return len(self.mentions)
@property
def participants(self):
"""Get all participants (creator + repliers + mentioned users)."""
participant_ids = set()
# Add comment creator
participant_ids.add(self.created_by)
# Add reply creators
participant_ids.update(reply.created_by for reply in self.replies)
# Add mentioned users
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
# Get account objects
participants = []
for user_id in participant_ids:
account = db.session.get(Account, user_id)
if account:
participants.append(account)
return participants
class WorkflowCommentReply(Base):
"""Workflow comment reply model.
Attributes:
id: Reply ID
comment_id: Parent comment ID
content: Reply content
created_by: Creator account ID
created_at: Creation time
"""
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@property
def created_by_account(self):
"""Get creator account."""
if hasattr(self, "_created_by_account_cache"):
return self._created_by_account_cache
return db.session.get(Account, self.created_by)
def cache_created_by_account(self, account: Account | None) -> None:
"""Cache creator account to avoid extra queries."""
self._created_by_account_cache = account
class WorkflowCommentMention(Base):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
cannot access workflow canvas and commenting features.
Attributes:
id: Mention ID
comment_id: Parent comment ID
mentioned_user_id: Mentioned account ID
"""
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
@property
def mentioned_user_account(self):
"""Get mentioned account."""
if hasattr(self, "_mentioned_user_account_cache"):
return self._mentioned_user_account_cache
return db.session.get(Account, self.mentioned_user_id)
def cache_mentioned_user_account(self, account: Account | None) -> None:
"""Cache mentioned account to avoid extra queries."""
self._mentioned_user_account_cache = account

View File

@@ -401,7 +401,7 @@ class Workflow(Base): # bug
:return: hash
"""
entity = {"graph": self.graph_dict, "features": self.features_dict}
entity = {"graph": self.graph_dict}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
@@ -1285,7 +1285,7 @@ class WorkflowDraftVariable(Base):
# which may differ from the original value's type. Typically, they are the same,
# but in cases where the structurally truncated value still exceeds the size limit,
# text slicing is applied, and the `value_type` is converted to `STRING`.
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
# The variable's value serialized as a JSON string
#
@@ -1659,7 +1659,7 @@ class WorkflowDraftVariableFile(Base):
# The `value_type` field records the type of the original value.
value_type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=21),
EnumText(SegmentType, length=20),
nullable=False,
)

View File

@@ -21,6 +21,7 @@ dependencies = [
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent-websocket~=0.10.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-python-client==2.90.0",
@@ -72,6 +73,7 @@ dependencies = [
"pypdfium2==5.2.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"python-socketio~=5.13.0",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.1.0",

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
import json
from typing import TypedDict
from extensions.ext_redis import redis_client
SESSION_STATE_TTL_SECONDS = 3600
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
WS_SID_MAP_PREFIX = "ws_sid_map:"
class WorkflowSessionInfo(TypedDict):
user_id: str
username: str
avatar: str | None
sid: str
connected_at: int
class SidMapping(TypedDict):
workflow_id: str
user_id: str
class WorkflowCollaborationRepository:
def __init__(self) -> None:
self._redis = redis_client
def __repr__(self) -> str:
return f"{self.__class__.__name__}(redis_client={self._redis})"
@staticmethod
def workflow_key(workflow_id: str) -> str:
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
@staticmethod
def leader_key(workflow_id: str) -> str:
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
@staticmethod
def sid_key(sid: str) -> str:
return f"{WS_SID_MAP_PREFIX}{sid}"
@staticmethod
def _decode(value: str | bytes | None) -> str | None:
if value is None:
return None
if isinstance(value, bytes):
return value.decode("utf-8")
return value
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
workflow_key = self.workflow_key(workflow_id)
sid_key = self.sid_key(sid)
if self._redis.exists(workflow_key):
self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
if self._redis.exists(sid_key):
self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS)
def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None:
workflow_key = self.workflow_key(workflow_id)
self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info))
self._redis.set(
self.sid_key(session_info["sid"]),
json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}),
ex=SESSION_STATE_TTL_SECONDS,
)
self.refresh_session_state(workflow_id, session_info["sid"])
def get_sid_mapping(self, sid: str) -> SidMapping | None:
raw = self._redis.get(self.sid_key(sid))
if not raw:
return None
value = self._decode(raw)
if not value:
return None
try:
return json.loads(value)
except (TypeError, json.JSONDecodeError):
return None
def delete_session(self, workflow_id: str, sid: str) -> None:
self._redis.hdel(self.workflow_key(workflow_id), sid)
self._redis.delete(self.sid_key(sid))
def session_exists(self, workflow_id: str, sid: str) -> bool:
return bool(self._redis.hexists(self.workflow_key(workflow_id), sid))
def sid_mapping_exists(self, sid: str) -> bool:
return bool(self._redis.exists(self.sid_key(sid)))
def get_session_sids(self, workflow_id: str) -> list[str]:
raw_sids = self._redis.hkeys(self.workflow_key(workflow_id))
decoded_sids: list[str] = []
for sid in raw_sids:
decoded = self._decode(sid)
if decoded:
decoded_sids.append(decoded)
return decoded_sids
def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
sessions_json = self._redis.hgetall(self.workflow_key(workflow_id))
users: list[WorkflowSessionInfo] = []
for session_info_json in sessions_json.values():
value = self._decode(session_info_json)
if not value:
continue
try:
session_info = json.loads(value)
except (TypeError, json.JSONDecodeError):
continue
if not isinstance(session_info, dict):
continue
if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
continue
users.append(
{
"user_id": str(session_info["user_id"]),
"username": str(session_info["username"]),
"avatar": session_info.get("avatar"),
"sid": str(session_info["sid"]),
"connected_at": int(session_info.get("connected_at") or 0),
}
)
return users
def get_current_leader(self, workflow_id: str) -> str | None:
raw = self._redis.get(self.leader_key(workflow_id))
return self._decode(raw)
def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool:
return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS))
def set_leader(self, workflow_id: str, sid: str) -> None:
self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS)
def delete_leader(self, workflow_id: str) -> None:
self._redis.delete(self.leader_key(workflow_id))
def expire_leader(self, workflow_id: str) -> None:
self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS)

View File

@@ -13,10 +13,11 @@ import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -73,6 +74,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
@@ -1162,6 +1164,7 @@ class DocumentService:
Document.archived.is_(True),
),
}
DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip"
@classmethod
def normalize_display_status(cls, status: str | None) -> str | None:
@@ -1288,6 +1291,143 @@ class DocumentService:
else:
return None
@staticmethod
def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]:
"""Fetch documents for a dataset in a single batch query."""
if not document_ids:
return []
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
# Fetch all requested documents in one query to avoid N+1 lookups.
documents: Sequence[Document] = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.id.in_(document_id_list),
)
).all()
return documents
@staticmethod
def get_document_download_url(document: Document) -> str:
"""
Return a signed download URL for an upload-file document.
"""
upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
@staticmethod
def prepare_document_batch_download_zip(
*,
dataset_id: str,
document_ids: Sequence[str],
tenant_id: str,
current_user: Account,
) -> tuple[list[UploadFile], str]:
"""
Resolve upload files for batch ZIP downloads and generate a client-visible filename.
"""
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except NoPermissionError as e:
raise Forbidden(str(e))
upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download(
dataset_id=dataset_id,
document_ids=document_ids,
tenant_id=tenant_id,
)
upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids]
download_name = DocumentService._generate_document_batch_download_zip_filename()
return upload_files, download_name
@staticmethod
def _generate_document_batch_download_zip_filename() -> str:
"""
Generate a random attachment filename for the batch download ZIP.
"""
return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}"
@staticmethod
def _get_upload_file_id_for_upload_file_document(
document: Document,
*,
invalid_source_message: str,
missing_file_message: str,
) -> str:
"""
Normalize and validate `Document -> UploadFile` linkage for download flows.
"""
if document.data_source_type != "upload_file":
raise NotFound(invalid_source_message)
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
upload_file_id: str | None = data_source_info.get("upload_file_id")
if not upload_file_id:
raise NotFound(missing_file_message)
return str(upload_file_id)
@staticmethod
def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile:
"""
Load the `UploadFile` row for an upload-file document.
"""
upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
document,
invalid_source_message="Document does not have an uploaded file to download.",
missing_file_message="Uploaded file not found.",
)
upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
upload_file = upload_files_by_id.get(upload_file_id)
if not upload_file:
raise NotFound("Uploaded file not found.")
return upload_file
@staticmethod
def _get_upload_files_by_document_id_for_zip_download(
*,
dataset_id: str,
document_ids: Sequence[str],
tenant_id: str,
) -> dict[str, UploadFile]:
"""
Batch load upload files keyed by document id for ZIP downloads.
"""
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list)
documents_by_id: dict[str, Document] = {str(document.id): document for document in documents}
missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys())
if missing_document_ids:
raise NotFound("Document not found.")
upload_file_ids: list[str] = []
upload_file_ids_by_document_id: dict[str, str] = {}
for document_id, document in documents_by_id.items():
if document.tenant_id != tenant_id:
raise Forbidden("No permission.")
upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
document,
invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.",
missing_file_message="Only uploaded-file documents can be downloaded as ZIP.",
)
upload_file_ids.append(upload_file_id)
upload_file_ids_by_document_id[document_id] = upload_file_id
upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
if missing_upload_file_ids:
raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
return {
document_id: upload_files_by_id[upload_file_id]
for document_id, upload_file_id in upload_file_ids_by_document_id.items()
}
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()

View File

@@ -161,6 +161,7 @@ class SystemFeatureModel(BaseModel):
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
enable_collaboration_mode: bool = False
is_allow_register: bool = False
is_allow_create_workspace: bool = False
is_email_setup: bool = False
@@ -222,6 +223,7 @@ class FeatureService:
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""

View File

@@ -2,7 +2,11 @@ import base64
import hashlib
import os
import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal, Union
from zipfile import ZIP_DEFLATED, ZipFile
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
@@ -17,6 +21,7 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@@ -167,6 +172,9 @@ class FileService:
return upload_file
def get_file_preview(self, file_id: str):
"""
Return a short text preview extracted from a document file.
"""
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
@@ -253,3 +261,101 @@ class FileService:
return
storage.delete(upload_file.key)
session.delete(upload_file)
@staticmethod
def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
"""
Fetch `UploadFile` rows for a tenant in a single batch query.
This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
"""
if not upload_file_ids:
return {}
# Normalize and deduplicate ids before using them in the IN clause.
upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
# Fetch upload files in one query for efficient batch access.
upload_files: Sequence[UploadFile] = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == tenant_id,
UploadFile.id.in_(unique_upload_file_ids),
)
).all()
return {str(upload_file.id): upload_file for upload_file in upload_files}
@staticmethod
def _sanitize_zip_entry_name(name: str) -> str:
"""
Sanitize a ZIP entry name to avoid path traversal and weird separators.
We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data)
could still contain unsafe names.
"""
# Drop any directory components and prevent empty names.
base = os.path.basename(name).strip() or "file"
# ZIP uses forward slashes as separators; remove any residual separator characters.
return base.replace("/", "_").replace("\\", "_")
@staticmethod
def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
"""
Return a unique ZIP entry name, inserting suffixes before the extension.
"""
# Keep the original name when it's not already used.
if original_name not in used_names:
return original_name
# Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
stem, extension = os.path.splitext(original_name)
suffix = 1
while True:
candidate = f"{stem} ({suffix}){extension}"
if candidate not in used_names:
return candidate
suffix += 1
@staticmethod
@contextmanager
def build_upload_files_zip_tempfile(
*,
upload_files: Sequence[UploadFile],
) -> Iterator[str]:
"""
Build a ZIP from `UploadFile`s and yield a tempfile path.
We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug
streams responses. The caller is expected to keep this context open until the response is fully sent, then
close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
"""
used_names: set[str] = set()
# Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
tmp_path: str | None = None
try:
with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
tmp_path = tmp.name
with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
for upload_file in upload_files:
# Ensure the entry name is safe and unique.
safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
used_names.add(arcname)
# Stream file bytes from storage into the ZIP entry.
with zf.open(arcname, "w") as entry:
for chunk in storage.load(upload_file.key, stream=True):
entry.write(chunk)
# Flush so `send_file(path, ...)` can re-open it safely on all platforms.
tmp.flush()
assert tmp_path is not None
yield tmp_path
finally:
# Remove the temp file when the context is closed (typically after the response finishes streaming).
if tmp_path is not None:
with suppress(FileNotFoundError):
os.remove(tmp_path)

View File

@@ -7,7 +7,6 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload
from configs import dify_config
from core.file.models import File
from core.model_runtime.entities import PromptMessage
from core.variables.segments import (
ArrayFileSegment,
ArraySegment,
@@ -288,10 +287,6 @@ class VariableTruncator(BaseTruncator):
if isinstance(item, File):
truncated_value.append(item)
continue
# Handle PromptMessage types - convert to dict for truncation
if isinstance(item, PromptMessage):
truncated_value.append(item)
continue
if i >= target_length:
return _PartResult(truncated_value, used_size, True)
if i > 0:

View File

@@ -163,29 +163,3 @@ class WorkflowScheduleCFSPlanEntity(BaseModel):
schedule_strategy: Strategy
granularity: int = Field(default=-1) # -1 means infinite
# ========== Mention Graph Entities ==========
class MentionParameterSchema(BaseModel):
"""Schema for the parameter to be extracted from mention context."""
name: str = Field(description="Parameter name (e.g., 'query')")
type: str = Field(default="string", description="Parameter type (e.g., 'string', 'number')")
description: str = Field(default="", description="Parameter description for LLM")
class MentionGraphRequest(BaseModel):
"""Request payload for generating mention graph."""
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
parameter_key: str = Field(description="Key of the parameter being extracted")
context_source: list[str] = Field(description="Variable selector for the context source")
parameter_schema: MentionParameterSchema = Field(description="Schema of the parameter to extract")
class MentionGraphResponse(BaseModel):
"""Response containing the generated mention graph."""
graph: Mapping[str, Any] = Field(description="Complete graph structure with nodes, edges, viewport")

View File

@@ -1,143 +0,0 @@
"""
Service for generating Mention LLM node graph structures.
This service creates graph structures containing LLM nodes configured for
extracting values from list[PromptMessage] variables.
"""
from typing import Any
from sqlalchemy.orm import Session
from core.model_runtime.entities import LLMMode
from core.workflow.enums import NodeType
from services.model_provider_service import ModelProviderService
from services.workflow.entities import MentionGraphRequest, MentionGraphResponse, MentionParameterSchema
class MentionGraphService:
"""Service for generating Mention LLM node graph structures."""
def __init__(self, session: Session):
self._session = session
def generate_mention_node_id(self, node_id: str, parameter_name: str) -> str:
"""Generate mention node ID following the naming convention.
Format: {node_id}_ext_{parameter_name}
"""
return f"{node_id}_ext_{parameter_name}"
def generate_mention_graph(self, tenant_id: str, request: MentionGraphRequest) -> MentionGraphResponse:
"""Generate a complete graph structure containing a Mention LLM node.
Args:
tenant_id: The tenant ID for fetching default model config
request: The mention graph generation request
Returns:
Complete graph structure with nodes, edges, and viewport
"""
node_id = self.generate_mention_node_id(request.parent_node_id, request.parameter_key)
model_config = self._get_default_model_config(tenant_id)
node = self._build_mention_llm_node(
node_id=node_id,
parent_node_id=request.parent_node_id,
context_source=request.context_source,
parameter_schema=request.parameter_schema,
model_config=model_config,
)
graph = {
"nodes": [node],
"edges": [],
"viewport": {},
}
return MentionGraphResponse(graph=graph)
def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]:
"""Get the default LLM model configuration for the tenant."""
model_provider_service = ModelProviderService()
default_model = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type="llm",
)
if default_model:
return {
"provider": default_model.provider.provider,
"name": default_model.model,
"mode": LLMMode.CHAT.value,
"completion_params": {},
}
# Fallback to empty config if no default model is configured
return {
"provider": "",
"name": "",
"mode": LLMMode.CHAT.value,
"completion_params": {},
}
def _build_mention_llm_node(
self,
*,
node_id: str,
parent_node_id: str,
context_source: list[str],
parameter_schema: MentionParameterSchema,
model_config: dict[str, Any],
) -> dict[str, Any]:
"""Build the Mention LLM node structure.
The node uses:
- $context in prompt_template to reference the PromptMessage list
- structured_output for extracting the specific parameter
- parent_node_id to associate with the parent node
"""
prompt_template = [
{
"role": "system",
"text": "Extract the required parameter value from the conversation context above.",
},
{"$context": context_source},
{"role": "user", "text": ""},
]
structured_output = {
"schema": {
"type": "object",
"properties": {
parameter_schema.name: {
"type": parameter_schema.type,
"description": parameter_schema.description,
}
},
"required": [parameter_schema.name],
"additionalProperties": False,
}
}
return {
"id": node_id,
"position": {"x": 0, "y": 0},
"data": {
"type": NodeType.LLM.value,
"title": f"Mention: {parameter_schema.name}",
"desc": f"Extract {parameter_schema.name} from conversation context",
"parent_node_id": parent_node_id,
"model": model_config,
"prompt_template": prompt_template,
"context": {
"enabled": False,
"variable_selector": None,
},
"vision": {
"enabled": False,
},
"memory": None,
"structured_output_enabled": True,
"structured_output": structured_output,
},
}

View File

@@ -0,0 +1,196 @@
from __future__ import annotations
import logging
import time
from collections.abc import Mapping
from models.account import Account
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo
class WorkflowCollaborationService:
def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None:
self._repository = repository
self._socketio = socketio
def __repr__(self) -> str:
return f"{self.__class__.__name__}(repository={self._repository})"
def save_session(self, sid: str, user: Account) -> None:
self._socketio.save_session(
sid,
{
"user_id": user.id,
"username": user.name,
"avatar": user.avatar,
},
)
def register_session(self, workflow_id: str, sid: str) -> tuple[str, bool] | None:
session = self._socketio.get_session(sid)
user_id = session.get("user_id")
if not user_id:
return None
session_info: WorkflowSessionInfo = {
"user_id": str(user_id),
"username": str(session.get("username", "Unknown")),
"avatar": session.get("avatar"),
"sid": sid,
"connected_at": int(time.time()),
}
self._repository.set_session_info(workflow_id, session_info)
leader_sid = self.get_or_set_leader(workflow_id, sid)
is_leader = leader_sid == sid
self._socketio.enter_room(sid, workflow_id)
self.broadcast_online_users(workflow_id)
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
return str(user_id), is_leader
def disconnect_session(self, sid: str) -> None:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return
workflow_id = mapping["workflow_id"]
self._repository.delete_session(workflow_id, sid)
self.handle_leader_disconnect(workflow_id, sid)
self.broadcast_online_users(workflow_id)
def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return {"msg": "unauthorized"}, 401
workflow_id = mapping["workflow_id"]
user_id = mapping["user_id"]
self.refresh_session_state(workflow_id, sid)
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type:
return {"msg": "invalid event type"}, 400
self._socketio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=workflow_id,
skip_sid=sid,
)
return {"msg": "event_broadcasted"}, 200
def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return {"msg": "unauthorized"}, 401
workflow_id = mapping["workflow_id"]
self.refresh_session_state(workflow_id, sid)
self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "graph_update_broadcasted"}, 200
def get_or_set_leader(self, workflow_id: str, sid: str) -> str:
current_leader = self._repository.get_current_leader(workflow_id)
if current_leader:
if self.is_session_active(workflow_id, current_leader):
return current_leader
self._repository.delete_session(workflow_id, current_leader)
self._repository.delete_leader(workflow_id)
was_set = self._repository.set_leader_if_absent(workflow_id, sid)
if was_set:
if current_leader:
self.broadcast_leader_change(workflow_id, sid)
return sid
current_leader = self._repository.get_current_leader(workflow_id)
if current_leader:
return current_leader
return sid
def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None:
current_leader = self._repository.get_current_leader(workflow_id)
if not current_leader:
return
if current_leader != disconnected_sid:
return
session_sids = self._repository.get_session_sids(workflow_id)
if session_sids:
new_leader_sid = session_sids[0]
self._repository.set_leader(workflow_id, new_leader_sid)
self.broadcast_leader_change(workflow_id, new_leader_sid)
else:
self._repository.delete_leader(workflow_id)
def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str) -> None:
for sid in self._repository.get_session_sids(workflow_id):
try:
is_leader = sid == new_leader_sid
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
except Exception:
logging.exception("Failed to emit leader status to session %s", sid)
def get_current_leader(self, workflow_id: str) -> str | None:
return self._repository.get_current_leader(workflow_id)
def broadcast_online_users(self, workflow_id: str) -> None:
users = self._repository.list_sessions(workflow_id)
users.sort(key=lambda x: x.get("connected_at") or 0)
leader_sid = self.get_current_leader(workflow_id)
self._socketio.emit(
"online_users",
{"workflow_id": workflow_id, "users": users, "leader": leader_sid},
room=workflow_id,
)
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
self._repository.refresh_session_state(workflow_id, sid)
self._ensure_leader(workflow_id, sid)
def _ensure_leader(self, workflow_id: str, sid: str) -> None:
current_leader = self._repository.get_current_leader(workflow_id)
if current_leader and self.is_session_active(workflow_id, current_leader):
self._repository.expire_leader(workflow_id)
return
if current_leader:
self._repository.delete_leader(workflow_id)
self._repository.set_leader(workflow_id, sid)
self.broadcast_leader_change(workflow_id, sid)
def is_session_active(self, workflow_id: str, sid: str) -> bool:
if not sid:
return False
try:
if not self._socketio.manager.is_connected(sid, "/"):
return False
except AttributeError:
return False
if not self._repository.session_exists(workflow_id, sid):
return False
if not self._repository.sid_mapping_exists(sid):
return False
return True

View File

@@ -0,0 +1,345 @@
import logging
from collections.abc import Sequence
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
from models.account import Account
logger = logging.getLogger(__name__)
class WorkflowCommentService:
"""Service for managing workflow comments."""
@staticmethod
def _validate_content(content: str) -> None:
if len(content.strip()) == 0:
raise ValueError("Comment content cannot be empty")
if len(content) > 1000:
raise ValueError("Comment content cannot exceed 1000 characters")
@staticmethod
def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]:
"""Get all comments for a workflow."""
with Session(db.engine) as session:
# Get all comments with eager loading
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
.order_by(desc(WorkflowComment.created_at))
)
comments = session.scalars(stmt).all()
# Batch preload all Account objects to avoid N+1 queries
WorkflowCommentService._preload_accounts(session, comments)
return comments
@staticmethod
def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None:
"""Batch preload Account objects for comments, replies, and mentions."""
# Collect all user IDs
user_ids: set[str] = set()
for comment in comments:
user_ids.add(comment.created_by)
if comment.resolved_by:
user_ids.add(comment.resolved_by)
user_ids.update(reply.created_by for reply in comment.replies)
user_ids.update(mention.mentioned_user_id for mention in comment.mentions)
if not user_ids:
return
# Batch query all accounts
accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
account_map = {str(account.id): account for account in accounts}
# Cache accounts on objects
for comment in comments:
comment.cache_created_by_account(account_map.get(comment.created_by))
comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None)
for reply in comment.replies:
reply.cache_created_by_account(account_map.get(reply.created_by))
for mention in comment.mentions:
mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id))
@staticmethod
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment:
"""Get a specific comment."""
def _get_comment(session: Session) -> WorkflowComment:
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Preload accounts to avoid N+1 queries
WorkflowCommentService._preload_accounts(session, [comment])
return comment
if session is not None:
return _get_comment(session)
else:
with Session(db.engine, expire_on_commit=False) as session:
return _get_comment(session)
@staticmethod
def create_comment(
tenant_id: str,
app_id: str,
created_by: str,
content: str,
position_x: float,
position_y: float,
mentioned_user_ids: list[str] | None = None,
) -> dict:
"""Create a new workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine) as session:
comment = WorkflowComment(
tenant_id=tenant_id,
app_id=app_id,
position_x=position_x,
position_y=position_y,
content=content,
created_by=created_by,
)
session.add(comment)
session.flush() # Get the comment ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention, not reply mention
mentioned_user_id=user_id,
)
session.add(mention)
session.commit()
# Return only what we need - id and created_at
return {"id": comment.id, "created_at": comment.created_at}
@staticmethod
def update_comment(
tenant_id: str,
app_id: str,
comment_id: str,
user_id: str,
content: str,
position_x: float | None = None,
position_y: float | None = None,
mentioned_user_ids: list[str] | None = None,
) -> dict:
"""Update a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Get comment with validation
stmt = select(WorkflowComment).where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Only the creator can update the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can update it")
# Update comment fields
comment.content = content
if position_x is not None:
comment.position_x = position_x
if position_y is not None:
comment.position_y = position_y
# Update mentions - first remove existing mentions for this comment only (not replies)
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(
WorkflowCommentMention.comment_id == comment.id,
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add new mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention
mentioned_user_id=user_id_str,
)
session.add(mention)
session.commit()
return {"id": comment.id, "updated_at": comment.updated_at}
@staticmethod
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
"""Delete a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
# Only the creator can delete the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can delete it")
# Delete associated mentions (both comment and reply mentions)
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
).all()
for mention in mentions:
session.delete(mention)
# Delete associated replies
replies = session.scalars(
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
).all()
for reply in replies:
session.delete(reply)
session.delete(comment)
session.commit()
@staticmethod
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
"""Resolve a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
if comment.resolved:
return comment
comment.resolved = True
comment.resolved_at = naive_utc_now()
comment.resolved_by = user_id
session.commit()
return comment
@staticmethod
def create_reply(
comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None
) -> dict:
"""Add a reply to a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Check if comment exists
comment = session.get(WorkflowComment, comment_id)
if not comment:
raise NotFound("Comment not found")
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
session.add(reply)
session.flush() # Get the reply ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
# Create mention linking to specific reply
mention = WorkflowCommentMention(
comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
)
session.add(mention)
session.commit()
return {"id": reply.id, "created_at": reply.created_at}
@staticmethod
def update_reply(reply_id: str, user_id: str, content: str, mentioned_user_ids: list[str] | None = None) -> dict:
"""Update a comment reply."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can update the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can update it")
reply.content = content
# Update mentions - first remove existing mentions for this reply
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
)
session.add(mention)
session.commit()
session.refresh(reply) # Refresh to get updated timestamp
return {"id": reply.id, "updated_at": reply.updated_at}
@staticmethod
def delete_reply(reply_id: str, user_id: str) -> None:
"""Delete a comment reply."""
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can delete the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can delete it")
# Delete associated mentions first
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
).all()
for mention in mentions:
session.delete(mention)
session.delete(reply)
session.commit()
@staticmethod
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
"""Validate that a comment belongs to the specified tenant and app."""
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)

View File

@@ -249,6 +249,78 @@ class WorkflowService:
# return draft workflow
return workflow
def update_draft_workflow_environment_variables(
self,
*,
app_model: App,
environment_variables: Sequence[VariableBase],
account: Account,
):
"""
Update draft workflow environment variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.environment_variables = environment_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_conversation_variables(
self,
*,
app_model: App,
conversation_variables: Sequence[VariableBase],
account: Account,
):
"""
Update draft workflow conversation variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.conversation_variables = conversation_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_features(
self,
*,
app_model: App,
features: dict,
account: Account,
):
"""
Update draft workflow features
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def publish_workflow(
self,
*,

View File

@@ -1,181 +0,0 @@
app:
description: ''
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: file output schema
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
version: null
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- remote_url
- local_file
enabled: true
fileUploadConfig:
attachment_image_file_size_limit: 2
audio_file_size_limit: 50
batch_count_limit: 5
file_size_limit: 15
file_upload_limit: 10
image_file_batch_limit: 10
image_file_size_limit: 10
single_chunk_attachment_limit: 10
video_file_size_limit: 100
workflow_file_upload_limit: 10
number_limits: 3
opening_statement: ''
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
sourceType: start
targetType: llm
id: 1768292241666-llm
source: '1768292241666'
sourceHandle: source
target: llm
targetHandle: target
type: custom
- data:
sourceType: llm
targetType: answer
id: llm-answer
source: llm
sourceHandle: source
target: answer
targetHandle: target
type: custom
nodes:
- data:
selected: false
title: User Input
type: start
variables: []
height: 73
id: '1768292241666'
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
memory:
query_prompt_template: '{{#sys.query#}}
{{#sys.files#}}'
role_prefix:
assistant: ''
user: ''
window:
enabled: false
size: 10
model:
completion_params:
temperature: 0.7
mode: chat
name: gpt-4o-mini
provider: langgenius/openai/openai
prompt_template:
- id: e30d75d7-7d85-49ec-be3c-3baf7f6d3c5a
role: system
text: ''
selected: false
structured_output:
schema:
additionalProperties: false
properties:
image:
description: File ID (UUID) of the selected image
format: dify-file-ref
type: string
required:
- image
type: object
structured_output_enabled: true
title: LLM
type: llm
vision:
configs:
detail: high
variable_selector:
- sys
- files
enabled: true
height: 88
id: llm
position:
x: 380
y: 282
positionAbsolute:
x: 380
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '{{#llm.structured_output.image#}}'
selected: false
title: Answer
type: answer
variables: []
height: 103
id: answer
position:
x: 680
y: 282
positionAbsolute:
x: 680
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -149
y: 97.5
zoom: 1
rag_pipeline_variables: []

View File

@@ -1,307 +0,0 @@
app:
description: Test for variable extraction feature
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: pav-test-extraction
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
version: null
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- local_file
- remote_url
enabled: false
image:
enabled: false
number_limits: 3
transfer_methods:
- local_file
- remote_url
number_limits: 3
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
sourceType: start
targetType: llm
id: 1767773675796-llm
source: '1767773675796'
sourceHandle: source
target: llm
targetHandle: target
type: custom
- data:
isInIteration: false
isInLoop: false
sourceType: llm
targetType: tool
id: llm-source-1767773709491-target
source: llm
sourceHandle: source
target: '1767773709491'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: answer
id: tool-source-answer-target
source: '1767773709491'
sourceHandle: source
target: answer
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
selected: false
title: User Input
type: start
variables: []
height: 73
id: '1767773675796'
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
memory:
mode: node
query_prompt_template: '{{#sys.query#}}'
role_prefix:
assistant: ''
user: ''
window:
enabled: true
size: 10
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
prompt_template:
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
role: system
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
请与用户进行对话,了解他们的搜索需求。
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
'
selected: false
title: LLM
type: llm
vision:
enabled: false
height: 88
id: llm
position:
x: 380
y: 282
positionAbsolute:
x: 380
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: used for searching
ja_JP: used for searching
pt_BR: used for searching
zh_Hans: 用于搜索网页内容
label:
en_US: Query string
ja_JP: Query string
pt_BR: Query string
zh_Hans: 查询语句
llm_description: key words for searching
max: null
min: null
name: query
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
params:
query: ''
plugin_id: langgenius/google
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
provider_id: langgenius/google/google
provider_name: langgenius/google/google
provider_type: builtin
selected: false
title: GoogleSearch
tool_configurations: {}
tool_description: A tool for performing a Google SERP search and extracting
snippets and webpages.Input should be a search query.
tool_label: GoogleSearch
tool_name: google_search
tool_node_version: '2'
tool_parameters:
query:
type: mention
value: '{{@llm.context@}}请从对话历史中提取用户想要搜索的关键词,只返回关键词本身'
mention_config:
extractor_node_id: 1767773709491_ext_query
output_selector:
- structured_output
- query
null_strategy: use_default
default_value: ''
type: tool
height: 52
id: '1767773709491'
position:
x: 682
y: 282
positionAbsolute:
x: 682
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
model:
completion_params:
temperature: 0.7
mode: chat
name: gpt-4o-mini
provider: langgenius/openai/openai
parent_node_id: '1767773709491'
prompt_template:
- $context:
- llm
- context
id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
- id: 18ba6710-77f5-47f4-b144-9191833bb547
role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
selected: false
structured_output:
schema:
additionalProperties: false
properties:
query:
description: 搜索的关键词
type: string
required:
- query
type: object
structured_output_enabled: true
title: 提取搜索关键词
type: llm
vision:
enabled: false
height: 88
id: 1767773709491_ext_query
position:
x: 531
y: 382
positionAbsolute:
x: 531
y: 382
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '搜索结果:
{{#1767773709491.text#}}
'
selected: false
title: Answer
type: answer
height: 103
id: answer
position:
x: 984
y: 282
positionAbsolute:
x: 984
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -151
y: 123
zoom: 1
rag_pipeline_variables: []

View File

@@ -38,7 +38,7 @@ os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage"
os.environ.setdefault("STORAGE_TYPE", "opendal")
os.environ.setdefault("OPENDAL_SCHEME", "fs")
_CACHED_APP = create_app()
_SIO_APP, _CACHED_APP = create_app()
@pytest.fixture

View File

@@ -364,7 +364,7 @@ def _create_app_with_containers() -> Flask:
# Create and configure the Flask application
logger.info("Initializing Flask application...")
app = create_app()
sio_app, app = create_app()
logger.info("Flask application created successfully")
# Initialize database schema

View File

@@ -268,6 +268,7 @@ class TestFeatureService:
mock_config.ENABLE_EMAIL_CODE_LOGIN = True
mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
mock_config.ENABLE_COLLABORATION_MODE = True
mock_config.ALLOW_REGISTER = False
mock_config.ALLOW_CREATE_WORKSPACE = False
mock_config.MAIL_TYPE = "smtp"
@@ -292,6 +293,7 @@ class TestFeatureService:
# Verify authentication settings
assert result.enable_email_code_login is True
assert result.enable_email_password_login is False
assert result.enable_collaboration_mode is True
assert result.is_allow_register is False
assert result.is_allow_create_workspace is False
@@ -341,6 +343,7 @@ class TestFeatureService:
mock_config.ENABLE_EMAIL_CODE_LOGIN = True
mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
mock_config.ENABLE_COLLABORATION_MODE = False
mock_config.ALLOW_REGISTER = True
mock_config.ALLOW_CREATE_WORKSPACE = True
mock_config.MAIL_TYPE = "smtp"
@@ -362,6 +365,7 @@ class TestFeatureService:
assert result.enable_email_code_login is True
assert result.enable_email_password_login is True
assert result.enable_social_oauth_login is False
assert result.enable_collaboration_mode is False
assert result.is_allow_register is True
assert result.is_allow_create_workspace is True
assert result.is_email_setup is True

View File

@@ -1,254 +0,0 @@
"""
Unit tests for XSS prevention in App payloads.
This test module validates that HTML tags, JavaScript, and other potentially
dangerous content are rejected in App names and descriptions.
"""
import pytest
from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload
class TestXSSPreventionUnit:
"""Unit tests for XSS prevention in App payloads."""
def test_create_app_valid_names(self):
"""Test CreateAppPayload with valid app names."""
# Normal app names should be valid
valid_names = [
"My App",
"Test App 123",
"App with - dash",
"App with _ underscore",
"App with + plus",
"App with () parentheses",
"App with [] brackets",
"App with {} braces",
"App with ! exclamation",
"App with @ at",
"App with # hash",
"App with $ dollar",
"App with % percent",
"App with ^ caret",
"App with & ampersand",
"App with * asterisk",
"Unicode: 测试应用",
"Emoji: 🤖",
"Mixed: Test 测试 123",
]
for name in valid_names:
payload = CreateAppPayload(
name=name,
mode="chat",
)
assert payload.name == name
def test_create_app_xss_script_tags(self):
"""Test CreateAppPayload rejects script tags."""
xss_payloads = [
"<script>alert(document.cookie)</script>",
"<Script>alert(1)</Script>",
"<SCRIPT>alert('XSS')</SCRIPT>",
"<script>alert(String.fromCharCode(88,83,83))</script>",
"<script src='evil.js'></script>",
"<script>document.location='http://evil.com'</script>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_iframe_tags(self):
"""Test CreateAppPayload rejects iframe tags."""
xss_payloads = [
"<iframe src='evil.com'></iframe>",
"<Iframe srcdoc='<script>alert(1)</script>'></iframe>",
"<IFRAME src='javascript:alert(1)'></iframe>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_javascript_protocol(self):
"""Test CreateAppPayload rejects javascript: protocol."""
xss_payloads = [
"javascript:alert(1)",
"JAVASCRIPT:alert(1)",
"JavaScript:alert(document.cookie)",
"javascript:void(0)",
"javascript://comment%0Aalert(1)",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_svg_onload(self):
"""Test CreateAppPayload rejects SVG with onload."""
xss_payloads = [
"<svg onload=alert(1)>",
"<SVG ONLOAD=alert(1)>",
"<svg/x/onload=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_event_handlers(self):
"""Test CreateAppPayload rejects HTML event handlers."""
xss_payloads = [
"<div onclick=alert(1)>",
"<img onerror=alert(1)>",
"<body onload=alert(1)>",
"<input onfocus=alert(1)>",
"<a onmouseover=alert(1)>",
"<DIV ONCLICK=alert(1)>",
"<img src=x onerror=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_object_embed(self):
"""Test CreateAppPayload rejects object and embed tags."""
xss_payloads = [
"<object data='evil.swf'></object>",
"<embed src='evil.swf'>",
"<OBJECT data='javascript:alert(1)'></OBJECT>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_link_javascript(self):
"""Test CreateAppPayload rejects link tags with javascript."""
xss_payloads = [
"<link href='javascript:alert(1)'>",
"<LINK HREF='javascript:alert(1)'>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_in_description(self):
"""Test CreateAppPayload rejects XSS in description."""
xss_descriptions = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for description in xss_descriptions:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(
name="Valid Name",
mode="chat",
description=description,
)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_valid_descriptions(self):
"""Test CreateAppPayload with valid descriptions."""
valid_descriptions = [
"A simple description",
"Description with < and > symbols",
"Description with & ampersand",
"Description with 'quotes' and \"double quotes\"",
"Description with / slashes",
"Description with \\ backslashes",
"Description with ; semicolons",
"Unicode: 这是一个描述",
"Emoji: 🎉🚀",
]
for description in valid_descriptions:
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=description,
)
assert payload.description == description
def test_create_app_none_description(self):
"""Test CreateAppPayload with None description."""
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=None,
)
assert payload.description is None
def test_update_app_xss_prevention(self):
"""Test UpdateAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
UpdateAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_update_app_valid_names(self):
"""Test UpdateAppPayload with valid names."""
payload = UpdateAppPayload(name="Valid Updated Name")
assert payload.name == "Valid Updated Name"
def test_copy_app_xss_prevention(self):
"""Test CopyAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
CopyAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_copy_app_valid_names(self):
"""Test CopyAppPayload with valid names."""
payload = CopyAppPayload(name="Valid Copy Name")
assert payload.name == "Valid Copy Name"
def test_copy_app_none_name(self):
"""Test CopyAppPayload with None name (should be allowed)."""
payload = CopyAppPayload(name=None)
assert payload.name is None
def test_edge_case_angle_brackets_content(self):
"""Test that angle brackets with actual content are rejected."""
# Angle brackets without valid HTML-like patterns should be checked
# The regex pattern <.*?on\w+\s*= should catch event handlers
# But let's verify other patterns too
# Valid: angle brackets used as symbols (not matched by our patterns)
# Our patterns specifically look for dangerous constructs
# Invalid: actual HTML tags with event handlers
invalid_names = [
"<div onclick=xss>",
"<img src=x onerror=alert(1)>",
]
for name in invalid_names:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()

View File

@@ -0,0 +1,430 @@
"""
Unit tests for the dataset document download endpoint.
These tests validate that the controller returns a signed download URL for
upload-file documents, and rejects unsupported or missing file cases.
"""
from __future__ import annotations
import importlib
import sys
from collections import UserDict
from io import BytesIO
from types import SimpleNamespace
from typing import Any
from zipfile import ZipFile
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
@pytest.fixture
def app() -> Flask:
"""Create a minimal Flask app for request-context based controller tests."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def datasets_document_module(monkeypatch: pytest.MonkeyPatch):
"""
Reload `controllers.console.datasets.datasets_document` with lightweight decorators.
We patch auth / setup / rate-limit decorators to no-ops so we can unit test the
controller logic without requiring the full console stack.
"""
from controllers.console import console_ns, wraps
from libs import login
def _noop(func): # type: ignore[no-untyped-def]
return func
# Bypass login/setup/account checks in unit tests.
monkeypatch.setattr(login, "login_required", _noop)
monkeypatch.setattr(wraps, "setup_required", _noop)
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
# Bypass billing-related decorators used by other endpoints in this module.
monkeypatch.setattr(wraps, "cloud_edition_billing_resource_check", lambda *_args, **_kwargs: (lambda f: f))
monkeypatch.setattr(wraps, "cloud_edition_billing_rate_limit_check", lambda *_args, **_kwargs: (lambda f: f))
# Avoid Flask-RESTX route registration side effects during import.
def _noop_route(*_args, **_kwargs): # type: ignore[override]
def _decorator(cls):
return cls
return _decorator
monkeypatch.setattr(console_ns, "route", _noop_route)
module_name = "controllers.console.datasets.datasets_document"
sys.modules.pop(module_name, None)
return importlib.import_module(module_name)
def _mock_user(*, is_dataset_editor: bool = True) -> SimpleNamespace:
"""Build a minimal user object compatible with dataset permission checks."""
return SimpleNamespace(is_dataset_editor=is_dataset_editor, id="user-123")
def _mock_document(
*,
document_id: str,
tenant_id: str,
data_source_type: str,
upload_file_id: str | None,
) -> SimpleNamespace:
"""Build a minimal document object used by the controller."""
data_source_info_dict: dict[str, Any] | None = None
if upload_file_id is not None:
data_source_info_dict = {"upload_file_id": upload_file_id}
else:
data_source_info_dict = {}
return SimpleNamespace(
id=document_id,
tenant_id=tenant_id,
data_source_type=data_source_type,
data_source_info_dict=data_source_info_dict,
)
def _wire_common_success_mocks(
*,
module,
monkeypatch: pytest.MonkeyPatch,
current_tenant_id: str,
document_tenant_id: str,
data_source_type: str,
upload_file_id: str | None,
upload_file_exists: bool,
signed_url: str,
) -> None:
"""Patch controller dependencies to create a deterministic test environment."""
import services.dataset_service as dataset_service_module
# Make `current_account_with_tenant()` return a known user + tenant id.
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (_mock_user(), current_tenant_id))
# Return a dataset object and allow permission checks to pass.
monkeypatch.setattr(module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1"))
monkeypatch.setattr(module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None)
# Return a document that will be validated inside DocumentResource.get_document.
document = _mock_document(
document_id="doc-1",
tenant_id=document_tenant_id,
data_source_type=data_source_type,
upload_file_id=upload_file_id,
)
monkeypatch.setattr(module.DocumentService, "get_document", lambda *_args, **_kwargs: document)
# Mock UploadFile lookup via FileService batch helper.
upload_files_by_id: dict[str, Any] = {}
if upload_file_exists and upload_file_id is not None:
upload_files_by_id[str(upload_file_id)] = SimpleNamespace(id=str(upload_file_id))
monkeypatch.setattr(module.FileService, "get_upload_files_by_ids", lambda *_args, **_kwargs: upload_files_by_id)
# Mock signing helper so the returned URL is deterministic.
monkeypatch.setattr(dataset_service_module.file_helpers, "get_signed_file_url", lambda **_kwargs: signed_url)
def _mock_send_file(obj, **kwargs): # type: ignore[no-untyped-def]
"""Return a lightweight representation of `send_file(...)` for unit tests."""
class _ResponseMock(UserDict):
def __init__(self, sent_file: object, send_file_kwargs: dict[str, object]) -> None:
super().__init__({"_sent_file": sent_file, "_send_file_kwargs": send_file_kwargs})
self._on_close: object | None = None
def call_on_close(self, func): # type: ignore[no-untyped-def]
self._on_close = func
return func
return _ResponseMock(obj, kwargs)
def test_batch_download_zip_returns_send_file(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure batch ZIP download returns a zip attachment via `send_file`."""
# Arrange common permission mocks.
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
monkeypatch.setattr(
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
)
monkeypatch.setattr(
datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None
)
# Two upload-file documents, each referencing an UploadFile.
doc1 = _mock_document(
document_id="11111111-1111-1111-1111-111111111111",
tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-1",
)
doc2 = _mock_document(
document_id="22222222-2222-2222-2222-222222222222",
tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-2",
)
monkeypatch.setattr(
datasets_document_module.DocumentService,
"get_documents_by_ids",
lambda *_args, **_kwargs: [doc1, doc2],
)
monkeypatch.setattr(
datasets_document_module.FileService,
"get_upload_files_by_ids",
lambda *_args, **_kwargs: {
"file-1": SimpleNamespace(id="file-1", name="a.txt", key="k1"),
"file-2": SimpleNamespace(id="file-2", name="b.txt", key="k2"),
},
)
# Mock storage streaming content.
import services.file_service as file_service_module
monkeypatch.setattr(file_service_module.storage, "load", lambda _key, stream=True: [b"hello"])
# Replace send_file used by the controller to avoid a real Flask response object.
monkeypatch.setattr(datasets_document_module, "send_file", _mock_send_file)
# Act
with app.test_request_context(
"/datasets/ds-1/documents/download-zip",
method="POST",
json={"document_ids": ["11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222"]},
):
api = datasets_document_module.DocumentBatchDownloadZipApi()
result = api.post(dataset_id="ds-1")
# Assert: we returned via send_file with correct mime type and attachment.
assert result["_send_file_kwargs"]["mimetype"] == "application/zip"
assert result["_send_file_kwargs"]["as_attachment"] is True
assert isinstance(result["_send_file_kwargs"]["download_name"], str)
assert result["_send_file_kwargs"]["download_name"].endswith(".zip")
# Ensure our cleanup hook is registered and execute it to avoid temp file leaks in unit tests.
assert getattr(result, "_on_close", None) is not None
result._on_close() # type: ignore[attr-defined]
def test_batch_download_zip_response_is_openable_zip(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure the real Flask `send_file` response body is a valid ZIP that can be opened."""
# Arrange: same controller mocks as the lightweight send_file test, but we keep the real `send_file`.
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
monkeypatch.setattr(
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
)
monkeypatch.setattr(
datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None
)
doc1 = _mock_document(
document_id="33333333-3333-3333-3333-333333333333",
tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-1",
)
doc2 = _mock_document(
document_id="44444444-4444-4444-4444-444444444444",
tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-2",
)
monkeypatch.setattr(
datasets_document_module.DocumentService,
"get_documents_by_ids",
lambda *_args, **_kwargs: [doc1, doc2],
)
monkeypatch.setattr(
datasets_document_module.FileService,
"get_upload_files_by_ids",
lambda *_args, **_kwargs: {
"file-1": SimpleNamespace(id="file-1", name="a.txt", key="k1"),
"file-2": SimpleNamespace(id="file-2", name="b.txt", key="k2"),
},
)
# Stream distinct bytes per key so we can verify both ZIP entries.
import services.file_service as file_service_module
monkeypatch.setattr(
file_service_module.storage, "load", lambda key, stream=True: [b"one"] if key == "k1" else [b"two"]
)
# Act
with app.test_request_context(
"/datasets/ds-1/documents/download-zip",
method="POST",
json={"document_ids": ["33333333-3333-3333-3333-333333333333", "44444444-4444-4444-4444-444444444444"]},
):
api = datasets_document_module.DocumentBatchDownloadZipApi()
response = api.post(dataset_id="ds-1")
# Assert: response body is a valid ZIP and contains the expected entries.
response.direct_passthrough = False
data = response.get_data()
response.close()
with ZipFile(BytesIO(data), mode="r") as zf:
assert zf.namelist() == ["a.txt", "b.txt"]
assert zf.read("a.txt") == b"one"
assert zf.read("b.txt") == b"two"
def test_batch_download_zip_rejects_non_upload_file_document(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure batch ZIP download rejects non upload-file documents."""
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
monkeypatch.setattr(
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
)
monkeypatch.setattr(
datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None
)
doc = _mock_document(
document_id="55555555-5555-5555-5555-555555555555",
tenant_id="tenant-123",
data_source_type="website_crawl",
upload_file_id="file-1",
)
monkeypatch.setattr(
datasets_document_module.DocumentService,
"get_documents_by_ids",
lambda *_args, **_kwargs: [doc],
)
with app.test_request_context(
"/datasets/ds-1/documents/download-zip",
method="POST",
json={"document_ids": ["55555555-5555-5555-5555-555555555555"]},
):
api = datasets_document_module.DocumentBatchDownloadZipApi()
with pytest.raises(NotFound):
api.post(dataset_id="ds-1")
def test_document_download_returns_url_for_upload_file_document(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure upload-file documents return a `{url}` JSON payload."""
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-123",
upload_file_exists=True,
signed_url="https://example.com/signed",
)
# Build a request context then call the resource method directly.
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
result = api.get(dataset_id="ds-1", document_id="doc-1")
assert result == {"url": "https://example.com/signed"}
def test_document_download_rejects_non_upload_file_document(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure non-upload documents raise 404 (no file to download)."""
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="website_crawl",
upload_file_id="file-123",
upload_file_exists=True,
signed_url="https://example.com/signed",
)
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
with pytest.raises(NotFound):
api.get(dataset_id="ds-1", document_id="doc-1")
def test_document_download_rejects_missing_upload_file_id(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure missing `upload_file_id` raises 404."""
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id=None,
upload_file_exists=False,
signed_url="https://example.com/signed",
)
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
with pytest.raises(NotFound):
api.get(dataset_id="ds-1", document_id="doc-1")
def test_document_download_rejects_when_upload_file_record_missing(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure missing UploadFile row raises 404."""
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-123",
upload_file_exists=False,
signed_url="https://example.com/signed",
)
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
with pytest.raises(NotFound):
api.get(dataset_id="ds-1", document_id="doc-1")
def test_document_download_rejects_tenant_mismatch(
app: Flask, datasets_document_module, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure tenant mismatch is rejected by the shared `get_document()` permission check."""
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-999",
data_source_type="upload_file",
upload_file_id="file-123",
upload_file_exists=True,
signed_url="https://example.com/signed",
)
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
with pytest.raises(Forbidden):
api.get(dataset_id="ds-1", document_id="doc-1")

View File

@@ -1,182 +0,0 @@
"""Tests for file_manager module, specifically multimodal content handling."""
from unittest.mock import patch
from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import (
_encode_file_ref,
restore_multimodal_content,
to_prompt_message_content,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
class TestEncodeFileRef:
"""Tests for _encode_file_ref function."""
def test_encodes_local_file(self):
"""Local file should be encoded as 'local:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="abc123",
storage_key="key",
)
assert _encode_file_ref(file) == "local:abc123"
def test_encodes_tool_file(self):
"""Tool file should be encoded as 'tool:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="xyz789",
storage_key="key",
)
assert _encode_file_ref(file) == "tool:xyz789"
def test_encodes_remote_url(self):
"""Remote URL should be encoded as 'remote:url'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.png",
storage_key="",
)
assert _encode_file_ref(file) == "remote:https://example.com/image.png"
class TestToPromptMessageContent:
"""Tests for to_prompt_message_content function with file_ref field."""
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._get_encoded_string")
def test_includes_file_ref(self, mock_get_encoded, mock_config):
"""Generated content should include file_ref field."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_get_encoded.return_value = "base64data"
file = File(
id="test-message-file-id",
tenant_id="test-tenant",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-related-id",
remote_url=None,
extension=".png",
mime_type="image/png",
filename="test.png",
storage_key="test-key",
)
result = to_prompt_message_content(file)
assert isinstance(result, ImagePromptMessageContent)
assert result.file_ref == "local:test-related-id"
assert result.base64_data == "base64data"
class TestRestoreMultimodalContent:
"""Tests for restore_multimodal_content function."""
def test_returns_content_unchanged_when_no_file_ref(self):
"""Content without file_ref should pass through unchanged."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref=None,
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_data(self):
"""Content that already has base64_data should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_url(self):
"""Content that already has url should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://example.com/image.png"
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._to_url")
def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
"""Content should be restored from file_ref when url is empty (url mode)."""
mock_config.MULTIMODAL_SEND_FORMAT = "url"
mock_build_file.return_value = "mock_file"
mock_to_url.return_value = "https://restored-url.com/image.png"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://restored-url.com/image.png"
mock_build_file.assert_called_once()
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._get_encoded_string")
def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
"""Content should be restored as base64 when in base64 mode."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_build_file.return_value = "mock_file"
mock_get_encoded.return_value = "restored-base64-data"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "restored-base64-data"
mock_build_file.assert_called_once()
def test_handles_invalid_file_ref_gracefully(self):
"""Invalid file_ref format should be handled gracefully."""
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
file_ref="invalid_format_no_colon",
)
result = restore_multimodal_content(content)
# Should return unchanged on error
assert result.base64_data == ""

View File

@@ -1,269 +0,0 @@
"""
Unit tests for file reference detection and conversion.
"""
import uuid
from unittest.mock import MagicMock, patch
import pytest
from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
FILE_REF_FORMAT,
convert_file_refs_in_output,
detect_file_ref_fields,
is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment
class TestIsFileRefProperty:
"""Tests for is_file_ref_property function."""
def test_valid_file_ref(self):
schema = {"type": "string", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is True
def test_invalid_type(self):
schema = {"type": "number", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is False
def test_missing_format(self):
schema = {"type": "string"}
assert is_file_ref_property(schema) is False
def test_wrong_format(self):
schema = {"type": "string", "format": "uuid"}
assert is_file_ref_property(schema) is False
class TestDetectFileRefFields:
"""Tests for detect_file_ref_fields function."""
def test_simple_file_ref(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["image"]
def test_multiple_file_refs(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
"document": {"type": "string", "format": FILE_REF_FORMAT},
"name": {"type": "string"},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "document"}
def test_array_of_file_refs(self):
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["files[*]"]
def test_nested_file_ref(self):
schema = {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["data.image"]
def test_no_file_refs(self):
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_empty_schema(self):
schema = {}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_mixed_schema(self):
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"documents": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "documents[*]"}
class TestConvertFileRefsInOutput:
"""Tests for convert_file_refs_in_output function."""
@pytest.fixture
def mock_file(self):
"""Create a mock File object with all required attributes."""
file = MagicMock(spec=File)
file.type = FileType.IMAGE
file.transfer_method = FileTransferMethod.TOOL_FILE
file.related_id = "test-related-id"
file.remote_url = None
file.tenant_id = "tenant_123"
file.id = None
file.filename = "test.png"
file.extension = ".png"
file.mime_type = "image/png"
file.size = 1024
file.dify_model_identity = "__dify__file__"
return file
@pytest.fixture
def mock_build_from_mapping(self, mock_file):
"""Mock the build_from_mapping function."""
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.return_value = mock_file
yield mock
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
# Result should be wrapped in FileSegment
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
mock_build_from_mapping.assert_called_once_with(
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
tenant_id="tenant_123",
)
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
output = {"files": [file_id1, file_id2]}
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
# Result should be wrapped in ArrayFileSegment
assert isinstance(result["files"], ArrayFileSegment)
assert list(result["files"].value) == [mock_file, mock_file]
assert mock_build_from_mapping.call_count == 2
def test_no_conversion_without_file_refs(self):
output = {"name": "test", "count": 5}
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result == {"name": "test", "count": 5}
def test_invalid_uuid_returns_none(self):
output = {"image": "not-a-valid-uuid"}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_file_not_found_returns_none(self):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.side_effect = ValueError("File not found")
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"query": "search term", "image": file_id, "count": 10}
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["query"] == "search term"
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
assert result["count"] == 10
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
original = {"image": file_id}
output = dict(original)
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
convert_file_refs_in_output(output, schema, "tenant_123")
# Original should still contain the string ID
assert original["image"] == file_id

View File

@@ -346,6 +346,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeRateLimitError",
"message": "Rate limit exceeded",
"args": {"description": "Rate limit exceeded"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -364,6 +365,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeAuthorizationError",
"message": "Invalid credentials",
"args": {"description": "Invalid credentials"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -382,6 +384,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeBadRequestError",
"message": "Invalid parameters",
"args": {"description": "Invalid parameters"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -400,6 +403,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeConnectionError",
"message": "Connection to external service failed",
"args": {"description": "Connection to external service failed"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
@@ -418,6 +422,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
invoke_error = {
"error_type": "InvokeServerUnavailableError",
"message": "Service temporarily unavailable",
"args": {"description": "Service temporarily unavailable"},
}
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})

Some files were not shown because too many files have changed in this diff Show More