Compare commits


172 Commits

Author SHA1 Message Date
Harry
2fa6684c4d Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-13 10:59:57 +08:00
-LAN-
fe6538b08d chore: disable workflow logs auto-cleanup by default (#28136)
This PR changes the default value of `WORKFLOW_LOG_CLEANUP_ENABLED` from `true` to `false` across all configuration files.

## Motivation

Setting the default to `false` provides safer default behavior by:

- Preventing unintended data loss for new installations
- Giving users explicit control over when to enable log cleanup
- Following the opt-in principle for data deletion features

Users who need automatic cleanup can enable it by setting `WORKFLOW_LOG_CLEANUP_ENABLED=true` in their configuration.
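For illustration, a minimal sketch (assuming a pydantic-settings config class like the `WorkflowLogConfig` shown in the diff below) of what the opt-in default means in practice:

```python
# Minimal sketch, not the actual Dify code: cleanup stays disabled unless the
# environment explicitly sets WORKFLOW_LOG_CLEANUP_ENABLED=true.
from pydantic import Field
from pydantic_settings import BaseSettings


class WorkflowLogConfig(BaseSettings):
    WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
    WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")


config = WorkflowLogConfig()
print(config.WORKFLOW_LOG_CLEANUP_ENABLED)  # False until the env var opts in
```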
2025-11-12 22:55:02 +08:00
Asuka Minato
1bbb9d6644 convert to TypeBase (#27935)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-12 21:50:13 +08:00
Gritty_dev
5c06e285ec test: create some hooks and utils test script, modified clipboard test script (#27928) 2025-11-12 21:47:06 +08:00
Gen Sato
19c92fd670 Add file type validation to paste upload (#28017) 2025-11-12 19:27:56 +08:00
非法操作
6026bd873b fix: variable assigner can't assign float number (#28068) 2025-11-12 19:27:36 +08:00
Bowen Liang
1369119a0c fix: cpu cores determination in basedpyright-check script on macos (#28058) 2025-11-12 19:27:27 +08:00
Yeuoly
b76e17b25d feat: introduce trigger functionality (#27644)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
Co-authored-by: Harry <xh001x@hotmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-12 17:59:37 +08:00
Yeuoly
87439b8fec Merge main 2025-11-12 17:49:17 +08:00
Yeuoly
08034532f6 Update trigger.py 2025-11-12 17:42:55 +08:00
Maries
c5f47ebccd Merge branch 'main' into feat/trigger 2025-11-12 17:14:35 +08:00
Jyong
ca7794305b add transform-datasource-credentials command online check (#28124)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
2025-11-12 17:13:44 +08:00
Yeuoly
6744306818 Merge branch 'main' into feat/trigger 2025-11-12 17:04:31 +08:00
QuantumGhost
fd255e81e1 feat(api): Introduce WorkflowResumptionContext for pause state management (#28122)
Certain metadata (including but not limited to `InvokeFrom`, `call_depth`, and `streaming`) is required when resuming a paused workflow. However, these fields are not part of `GraphRuntimeState` and were not saved in the previous implementation of `PauseStatePersistenceLayer`.

This commit addresses this limitation by introducing a `WorkflowResumptionContext` model that wraps both the `*GenerateEntity` and `GraphRuntimeState`. This approach provides:

- A structured container for all necessary resumption data
- Better separation of concerns between execution state and persistence
- Enhanced extensibility for future metadata additions
- Clearer naming that distinguishes it from `GraphRuntimeState`

The `WorkflowResumptionContext` model makes extending the pause state easier while maintaining backward compatibility and proper version management for the entire execution state ecosystem.
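As a rough, hypothetical sketch of the shape such a wrapper could take (the class and field names below are illustrative placeholders, not the actual model definition):

```python
# Hypothetical sketch: a versioned container that pairs the generate entity
# with the runtime state so both survive a pause/resume cycle.
from pydantic import BaseModel


class GenerateEntity(BaseModel):  # stand-in for the *GenerateEntity classes
    invoke_from: str
    call_depth: int = 0
    streaming: bool = True


class GraphRuntimeState(BaseModel):  # stand-in for the real runtime state
    node_states: dict[str, dict] = {}


class WorkflowResumptionContext(BaseModel):
    version: str = "1.0"  # allows the persisted pause state to evolve
    generate_entity: GenerateEntity
    graph_runtime_state: GraphRuntimeState


# Persist on pause, restore on resume:
ctx = WorkflowResumptionContext(
    generate_entity=GenerateEntity(invoke_from="debugger"),
    graph_runtime_state=GraphRuntimeState(),
)
restored = WorkflowResumptionContext.model_validate_json(ctx.model_dump_json())
```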

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-12 17:00:02 +08:00
lyzno1
3ff14ccc89 Merge branch 'main' into feat/trigger 2025-11-12 16:49:08 +08:00
Joel
09d31d1263 chore: improve the user experience of not login into apps (#28120) 2025-11-12 16:47:45 +08:00
Yeuoly
9c30f16e4b Update api/tasks/trigger_processing_tasks.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-12 16:42:12 +08:00
Harry
c31933c163 chore: downgrade dify-api version to 1.9.2 in uv.lock 2025-11-12 16:20:22 +08:00
Harry
574eb1a10a chore: downgrade version to 1.9.2 in pyproject.toml(ready for merge) 2025-11-12 16:18:19 +08:00
Harry
e6ac783fc3 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-12 16:01:19 +08:00
Jyong
47dc26f011 fix document index test (#28113) 2025-11-12 16:00:10 +08:00
Harry
689a75f44a Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-12 15:07:05 +08:00
湛露先生
123bb3ec08 When the graph_engine worker raises an exception, keep the node_id for deep res… (#26205)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-11-12 15:03:45 +08:00
Joel
90f77282e3 chore: non-SaaS version can query long log time range (#28109) 2025-11-12 14:45:56 +08:00
hjlarry
a25f469bde fix(trigger): Prevent marketplace tool downloads from retriggering on tab change 2025-11-12 13:40:56 +08:00
lyzno1
044ee7ef54 feat(workflow): always render featured tools section 2025-11-12 10:37:39 +08:00
hjlarry
2725f28fa8 fix(trigger): incorrect behavior when a node is uninstalled on the canvas 2025-11-12 10:12:49 +08:00
zhsama
36ad784251 feat(workflow-header): add conditional logic to disable publish and refresh actions based on workflow node presence 2025-11-11 20:08:09 +08:00
lyzno1
0a39e5c092 Fix modal query sync for settings & pricing 2025-11-11 19:32:48 +08:00
lyzno1
9169a5e35b Merge branch 'main' into feat/trigger 2025-11-11 18:05:23 +08:00
Jyong
5208867ccc fix document enable (#28081) 2025-11-11 17:50:45 +08:00
zhsama
bfdcb79e19 feat(card-view): enhance CardView to conditionally render AppCards based on trigger node presence in workflow 2025-11-11 16:54:06 +08:00
zhsama
c37cce000f refactor: replace TRIGGER_NODE_TYPES with isTriggerNode utility for improved node type checks across workflow components 2025-11-11 16:54:06 +08:00
autofix-ci[bot]
6d3fb9b769 [autofix.ci] apply automated fixes 2025-11-11 08:37:13 +00:00
Harry
c04913ecf8 refactor(tests): remove redundant graph validation tests from WorkflowService unit tests
- Deleted tests for graph initialization and error propagation that were deemed unnecessary.
- Cleaned up the test suite to improve maintainability and focus on essential validation scenarios.
2025-11-11 16:35:19 +08:00
lyzno1
3a84a64c32 refactor: unify account setting tab constants and tighten modal types 2025-11-11 16:16:41 +08:00
lyzno1
b344d4add1 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-11 16:16:22 +08:00
lyzno1
405a4ec9f8 feat: add URL parameter support for settings modal using action=showSettings 2025-11-11 16:16:06 +08:00
lyzno1
edc7ccc795 chore: add type-check to pre-commit (#28005) 2025-11-11 16:14:39 +08:00
Harry
8bb11a588c fix(tests): fix end node missing outputs 2025-11-11 16:14:00 +08:00
Harry
44f451bd7d refactor(api): improve graph validation logic in WorkflowService
- Updated the validate_graph_structure method to handle empty graph cases gracefully.
- Introduced a variable for workflow_id to ensure consistent handling of unknown workflow IDs.
- Enhanced code readability and maintainability by refining the method's structure.
2025-11-11 16:14:00 +08:00
zhsama
35d914e755 Merge remote-tracking branch 'origin/feat/trigger' into feat/trigger 2025-11-11 15:22:39 +08:00
zhsama
9c37f8c1cb feat(trigger): add support for trigger nodes and user input node management in workflow components 2025-11-11 15:21:04 +08:00
lyzno1
9de0e3c3a7 fix: add missing TimePicker type definitions for notClearable, triggerFullWidth, showTimezone and placement 2025-11-11 15:18:14 +08:00
lyzno1
707c94f86e feat: add URL parameter support for pricing modal using action=showPricing 2025-11-11 15:15:40 +08:00
lyzno1
81afd087f6 feat: add trigger events, workflow execution, and start nodes to billing plan features
- Add three new feature items to cloud plan list:
  - Trigger Events (varies by plan: 3K for sandbox, 20K/month for pro, unlimited for team)
  - Workflow Execution (standard/faster/priority based on plan)
  - Start Nodes (limited to 2 for sandbox, unlimited for pro/team)
- Add i18n translations for en-US and zh-Hans
- Position new items below document processing priority and above divider
2025-11-11 15:00:25 +08:00
lyzno1
0f952f328f feat: add api rate limit and trigger events billing card 2025-11-11 15:00:25 +08:00
autofix-ci[bot]
50619fba0a [autofix.ci] apply automated fixes 2025-11-11 06:54:03 +00:00
Harry
aad31bb703 feat(api): enhance workflow validation and structure checks
- Added a new validation class to ensure that trigger nodes do not coexist with UserInput (start) nodes in the workflow graph.
- Implemented a method in WorkflowService to validate the graph structure before persisting workflows, leveraging the new validation logic.
- Updated unit tests to cover the new validation scenarios and ensure proper error propagation.
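A hedged sketch of what such an exclusivity check might look like; the graph shape and node type strings below are assumptions for illustration, not the service's actual API:

```python
# Illustrative only: reject graphs that mix trigger nodes with a start node.
TRIGGER_NODE_TYPES = {"trigger", "webhook-trigger", "schedule-trigger"}  # assumed names
START_NODE_TYPE = "start"


def validate_graph_structure(graph: dict) -> None:
    node_types = [node.get("data", {}).get("type") for node in graph.get("nodes", [])]
    has_trigger = any(t in TRIGGER_NODE_TYPES for t in node_types)
    has_start = START_NODE_TYPE in node_types
    if has_trigger and has_start:
        raise ValueError("Trigger nodes cannot coexist with a UserInput (start) node.")


validate_graph_structure({"nodes": []})  # empty graphs pass (handled gracefully)
```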
2025-11-11 14:52:13 +08:00
hjlarry
7484a020e1 fix(trigger): subscription schema using bool field causes pydantic error 2025-11-11 14:05:11 +08:00
autofix-ci[bot]
186828c13a [autofix.ci] apply automated fixes 2025-11-11 04:47:22 +00:00
Harry
203fb95391 chore(api): update dependencies and default queue configurations
- Updated `revision` in `uv.lock` from 3 to 2.
- Added `croniter` package version 6.0.0 with dependencies in `uv.lock`.
- Updated `dify-api` version to 1.10.0rc1 and added `croniter` as a dependency.
- Modified default queue names in `entrypoint.sh` for both CLOUD and SELF_HOSTED editions to include `priority_dataset`.
2025-11-11 12:45:02 +08:00
Harry
a94e650ffd Merge remote-tracking branch 'origin/main' into feat/trigger
# Conflicts:
#	api/docker/entrypoint.sh
#	api/uv.lock
#	dev/start-worker
#	docker/.env.example
#	docker/docker-compose.yaml
#	web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx
#	web/app/components/base/date-and-time-picker/date-picker/index.tsx
#	web/app/components/base/date-and-time-picker/types.ts
2025-11-11 12:42:01 +08:00
Ali Saleh
c9798f6425 fix(api): Trace Hierarchy, Span Status, and Broken Workflow for Arize & Phoenix Integration (#27937)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-11 11:49:19 +08:00
crazywoola
20ecf7f1d0 chore: remove unused enterprise bot from the readme (#28073)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-11 10:52:27 +08:00
github-actions[bot]
9dcb780fcb chore: translate i18n files and update type definitions (#28054)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
2025-11-11 09:32:53 +08:00
Will
1cb7b09933 chore: Remove trailing space from migration filename (#28040) 2025-11-11 09:32:42 +08:00
Joel
2c62a77cf4 Chore: change query log time range (#28052)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-10 18:39:12 +08:00
QuantumGhost
b9bc48d8dd feat(api): Introduce Broadcast Channel (#27835)
This PR introduces a `BroadcastChannel` abstraction with broadcasting and at-most-once delivery semantics, serving as the communication component between the Celery worker and the API server.

It also includes a reference implementation backed by Redis PubSub.
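As a minimal sketch (assuming plain redis-py; class and method names are illustrative, not the abstraction introduced by this PR), a Pub/Sub-backed channel with at-most-once delivery could look like this:

```python
# Illustrative sketch: messages reach whoever is subscribed at publish time;
# if nobody is listening, the message is dropped (at-most-once delivery).
import redis


class RedisBroadcastChannel:
    def __init__(self, topic: str, client: redis.Redis | None = None):
        self._topic = topic
        self._client = client or redis.Redis()

    def publish(self, message: str) -> None:
        self._client.publish(self._topic, message)

    def subscribe(self) -> redis.client.PubSub:
        pubsub = self._client.pubsub(ignore_subscribe_messages=True)
        pubsub.subscribe(self._topic)
        return pubsub


# Celery worker side:  RedisBroadcastChannel("workflow-events").publish("paused:42")
# API server side:     for msg in channel.subscribe().listen(): handle(msg["data"])
```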

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-10 17:23:21 +08:00
Will
ed234e311b fix workflow default updated_at (#28047) 2025-11-10 18:20:38 +09:00
hjlarry
00fdd06179 fix(trigger): subscription schema config does not display field description 2025-11-10 13:43:56 +08:00
huangzhuo1949
9843fec393 fix: elasticsearch_vector version (#28028)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-10 13:17:13 +09:00
lyzno1
62fbc90389 refactor: update free plan rate limit description in pricing modal 2025-11-10 10:48:32 +08:00
Will
aa4cabdeb5 feat: Add Audio Content Support for MCP Tools (#27979) 2025-11-10 10:12:11 +08:00
NeatGuyCoding
eea713b668 Fix typo in weaviate comment, improve time test precision, and add security tests for get-icon utility (#27919)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-10 10:11:54 +08:00
hjlarry
f19a21da11 fix prepare userinput logic 2025-11-10 09:59:20 +08:00
dependabot[bot]
fc62538a94 chore(deps): bump scipy-stubs from 1.16.2.3 to 1.16.3.0 in /api (#28025)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-10 09:54:56 +08:00
Asuka Minato
7994144df7 add onupdate=func.current_timestamp() (#28014)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-10 01:48:52 +09:00
Kenn
e153c483b6 fix: the model list encountered two children with the same key (#27956)
Co-authored-by: haokai <haokai@shuwen.com>
2025-11-09 21:39:59 +08:00
wangxiaolei
422bb4d4bb fix: fix https://github.com/langgenius/dify/issues/27939 (#27985) 2025-11-09 21:39:05 +08:00
OneZero-Y
87a80d7613 docs: clarify how to obtain workflow_id for version execution (#28007)
Signed-off-by: OneZero-Y <aukovyps@163.com>
2025-11-09 21:38:06 +08:00
zhsama
7401792063 feat(last-run): add handling for Listening status in run result calculation 2025-11-07 16:39:30 +08:00
kenwoodjw
e91105ca87 fix: bump brotli to 1.2.0 to resolve CVE-2025-6176 (#27950)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-11-07 15:57:29 +08:00
zhsama
79e46c8a81 feat(step-run): add resolvedStatus calculation for improved run result handling 2025-11-07 15:19:40 +08:00
zhsama
7658c92cf9 feat(trigger): improve trigger node in useOneStepRun for getting system variables 2025-11-07 13:17:57 +08:00
hj24
37903722fe refactor: implement tenant self queue for rag tasks (#27559)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-11-06 21:25:50 +08:00
QuantumGhost
f4c82d0010 fix(api): fix VariablePool.get adding unexpected keys to variable_dictionary (#26767)
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-06 18:30:35 +08:00
NFish
fe50093c18 fix: prevent fetch version info in enterprise edition (#27923) 2025-11-06 17:59:53 +08:00
Jyong
4317af1e90 fix jina reader transform (#27922) 2025-11-06 17:35:53 +08:00
zhsama
85a5c78b80 feat: enhance workflow log components with detailed trigger metadata and type safety 2025-11-06 16:16:34 +08:00
lyzno1
9d7b47c784 trigger CI
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
2025-11-06 15:41:54 +08:00
lyzno1
e4c6ed9c60 Merge branch 'main' into feat/trigger 2025-11-06 15:25:19 +08:00
zhsama
fcfade4778 fix: update checkValidFns type to Partial and ensure error message fallback in useOneStepRun 2025-11-06 14:51:32 +08:00
zhsama
000e8bd12b fix: improve error handling in useOneStepRun and useWorkflowRun to provide structured error messages 2025-11-06 14:05:45 +08:00
Harry
ed8da2c760 fix: return structured error response for PluginInvokeError in workflow trigger APIs 2025-11-06 12:41:17 +08:00
zhsama
fb6dc14e9b refactor: simplify syncWorkflowDraft parameters 2025-11-06 12:19:09 +08:00
hjlarry
77e6e98234 fix CI 2025-11-06 09:46:43 +08:00
red_sun
61a0fcc2ea fix: agent putting out the output of workflow-tool twice (#26835) (#27087) 2025-11-06 09:41:05 +08:00
lyzno1
fb3699ec5e fix(web): resolve type checks in app operations 2025-11-05 20:08:59 +08:00
Jyong
f627348b11 fix jina reader credential migration command (#27883) 2025-11-05 18:42:07 +08:00
zhsama
4601be8b67 fix: update hasConnectedUserInput function to use specific types for nodes and edges 2025-11-05 18:12:57 +08:00
Cursx
87fb9a6b69 fix Version 2.0.0-beta.2: Chat annotations Api Error #25506 (#27206)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2025-11-05 17:37:19 +08:00
Yongtao Huang
97a2e2ec2e Fix: correct DraftWorkflowApi.post response model (#27289)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-05 17:20:40 +08:00
Boris Polonsky
68d357d7f6 Add WEAVIATE_GRPC_ENDPOINT as designed in weaviate migration guide (#27861)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-05 17:19:08 +08:00
zhsama
cc4d4adfb9 feat: enhance workflow draft processing by adding hydration and sanitization functions 2025-11-05 16:54:49 +08:00
lyzno1
6a08623949 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-05 16:50:53 +08:00
Harry
f0127ffc9a feat: enhance trigger metadata structure by adding type field and updating trigger_metadata handling 2025-11-05 16:49:29 +08:00
Harry
f1e513830c feat: implement logging for failed trigger invocations in workflow processing 2025-11-05 16:49:29 +08:00
Harry
7de533a643 feat: add start-web script to automate web project setup and execution 2025-11-05 16:49:29 +08:00
crazywoola
a103ad3ee7 bump vite to 6.4.1 (#27877) 2025-11-05 16:33:19 +08:00
wangjifeng
f65d5a9761 Fix/template transformer line number (#27867)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-05 15:21:47 +08:00
github-actions[bot]
6e0a5f5bbd chore: translate i18n files and update type definitions (#27868)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-05 15:17:53 +08:00
crazywoola
22f858152f feat: change feedback to forum (#27862) 2025-11-05 14:51:57 +08:00
autofix-ci[bot]
052127c473 [autofix.ci] apply automated fixes 2025-11-05 05:07:06 +00:00
Harry
7a4be5c0d2 Update package metadata in uv.lock: increment revision to 2, add upload times for sdist and wheels, and update aiohttp sdist version to 3.13.2. 2025-11-05 13:00:28 +08:00
Harry
a6208feed8 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-05 12:59:25 +08:00
Gritty_dev
775d2e14fc test: create new test scripts and update some existing test scripts o… (#27850) 2025-11-05 11:09:24 +08:00
lyzno1
c8f55549d7 Remove global pnpm installation from script 2025-11-05 10:27:45 +08:00
johnny0120
744b287e67 fix: avoid passing empty uniqueIdentifier to InstallFromMarketplace (#27802)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-05 10:22:22 +08:00
crazywoola
c0fc5d98f0 fix: installation_id is missing on the tools page (#27849) 2025-11-05 10:19:12 +08:00
lyzno1
132a86dcb3 fix: output node vars check 2025-11-05 10:15:57 +08:00
Elliott
08ea79d730 fix(web): increase z-index of PortalToFollowElemContent (#27823) 2025-11-05 09:32:15 +08:00
yangzheli
f31b821cc0 fix(web): improve the consistency of the inputs-form UI (#27837) 2025-11-05 09:29:13 +08:00
Novice
34be16874f feat: add validation to prevent saving empty opening statement in conversation opener modal (#27843) 2025-11-05 09:28:49 +08:00
aka James4u
e9738b891f test: adding some web tests (#27792) 2025-11-04 21:06:44 +08:00
zhsama
9f59baed10 fix(urlValidation): remove specific check for Dify cloud trigger debug URLs 2025-11-04 18:34:41 +08:00
zhsama
ce56286329 feat(validation): implement isPrivateOrLocalAddress utility and integrate into webhook components for improved URL validation 2025-11-04 18:32:19 +08:00
Harry
6e76f2aff2 fix(api): update TriggerInvokeEventResponse to use Field for default value of cancelled 2025-11-04 18:15:13 +08:00
Harry
49edd58722 fix(trigger): enhance credential handling by decrypting and masking subscription properties and parameters 2025-11-04 18:15:13 +08:00
zhsama
6a28aee13e Merge remote-tracking branch 'origin/feat/trigger' into feat/trigger 2025-11-04 16:51:37 +08:00
zhengchangchun
829796514a fix: knowledge base reference information is overwritten when using mu… (#27799)
Co-authored-by: zhengchangchun <zhengchangchun@corp.netease.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-04 16:40:44 +08:00
WTW0313
79c70d09c9 feat(marketplace): introduce IS_MARKETPLACE flag and update X-Dify-Version header logic 2025-11-04 16:09:50 +08:00
Novice
ef1db35f80 feat: implement file extension blacklist for upload security (#27540) 2025-11-04 15:45:22 +08:00
zhsama
b9bb97887b fix(workflow): handle node inspection variable deletion when not fetched 2025-11-04 15:25:15 +08:00
Cursx
f9c67621ca fix: agent putting out the output of workflow-tool twice (#26835) (#27706) 2025-11-04 14:24:51 +08:00
Guangdong Liu
e29e8e3180 feat: enhance annotation API to support optional message_id and content fields (#27460)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-04 14:11:09 +08:00
red_sun
7a81e720d4 fix: iteration node cannot be viewed (#27759) (#27786) 2025-11-04 12:37:31 +08:00
XlKsyt
55600c0eb1 feat: add metrics logging and improve MeterProvider lifecycle for tencent APM (#27733)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-04 12:35:53 +08:00
kenwoodjw
35e41d7d68 fix: bump pyobvector to 0.2.17 (#27791)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-11-04 12:25:50 +08:00
Ponder
b610cf9a11 feat: add segments max number limit for SegmentApi.post (#27745) 2025-11-04 10:27:58 +08:00
-LAN-
c8e9edc024 refactor(api): set default value for EasyUIBasedAppGenerateEntity.query (#27712) 2025-11-04 10:22:43 +08:00
49
471cd760d7 fix: improve infinite scroll observer responsiveness (#27546)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-04 10:15:27 +08:00
墨绿色
7f48c57edf fix: datasets weight settings embedding model does not change (#27694)
Co-authored-by: lijiezhao <lijiezhao@perfect99.com>
2025-11-04 10:00:36 +08:00
NeatGuyCoding
6569801162 extract parse_time_range for console app stats related queries (#27626)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-11-04 10:00:12 +08:00
国昊
9dd83f50a7 FIX Issue #27697: Add env variable in docker-compose(template) and make it take effect. (#27704) 2025-11-04 09:58:59 +08:00
CrabSAMA
59c56b1b0d fix: File model add known extra fields, fix issue about the tool of… (#27607) 2025-11-04 09:57:25 +08:00
lyzno1
7df6d9f1aa Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-04 09:56:02 +08:00
Tianzhi Jin
94cd2de940 fix(api): return timestamp as integer in document api (#27761) 2025-11-04 09:55:47 +08:00
lyzno1
587f83bc34 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-04 09:55:43 +08:00
heyszt
3c23375607 refactor: Use Repository Pattern for Model Layer (#27663) 2025-11-04 09:53:22 +08:00
lyzno1
d81b2e6820 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-04 09:52:08 +08:00
dependabot[bot]
56047f638f chore(deps): bump dayjs from 1.11.18 to 1.11.19 in /web (#27735)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-11-04 09:47:57 +08:00
vicen
9c01d3e775 fix: two web bugs for json-schema-config-modal (#27718) 2025-11-04 09:45:28 +08:00
lyzno1
8315e0c74b Merge remote-tracking branch 'origin/main' into feat/trigger 2025-11-04 09:44:13 +08:00
海狸大師
c85c87f3da fix(i18n/zh-Hant): unify terminology and improve translation consistency (#27717) 2025-11-04 09:42:26 +08:00
-LAN-
eaa02e3d55 Add SQLAlchemy Mapped annotations to MessageFeedback (#27768) 2025-11-04 09:39:59 +08:00
yihong
0219222a60 fix: pin litellm version ignore build issue (#27742)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-11-04 09:39:03 +08:00
yangzheli
dba659b220 fix(web): fix issues with links, Chinese translations, and styling on the logs page (#27669) 2025-11-04 09:38:15 +08:00
Bowen Liang
ee6458768e cleanup orphan packages in packages stage of api dockerfile (#27617) 2025-11-04 09:36:52 +08:00
Shemol
ed3d02dc6d web(markdown): support <think> without trailing newline in preprocessThinkTag (#27776)
Signed-off-by: SherlockShemol <shemol@163.com>
2025-11-04 09:35:54 +08:00
CrabSAMA
95471b1188 fix(ui): fix the empty placeholder bug when a plugin installs successfully (#27780) 2025-11-04 09:35:14 +08:00
aka James4u
6190cfbfd8 feat: localization for hi-IN (#27783)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-04 09:34:41 +08:00
hjlarry
3cc6690356 fix(trigger): global vars dialogue_count and timestamp auto-converted to string 2025-11-03 17:12:58 +08:00
hjlarry
6507263b28 fix(trigger): timestamp should be number type 2025-11-03 15:40:33 +08:00
hjlarry
8fa0bb48df fix CI 2025-11-03 14:57:01 +08:00
aka James4u
11f2f95103 Added it-IT for italian (#27665) 2025-11-03 11:51:45 +08:00
-LAN-
2abbc14703 refactor: replace hardcoded user plan strings with CloudPlan enum (#27675) 2025-11-03 11:51:09 +08:00
dependabot[bot]
b2b2816ade chore(deps): bump tablestore from 6.2.0 to 6.3.7 in /api (#27736) 2025-11-03 11:50:39 +08:00
hjlarry
637a675681 fix(trigger): workflow checklist does not work for knowledge pipeline 2025-11-03 09:42:33 +08:00
-LAN-
4461df1bd9 refactor(api): add SQLAlchemy 2.x Mapped type hints to Message model (#27709)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-01 01:16:07 +08:00
lyzno1
085ada86e6 chore: trigger ci
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
2025-10-31 20:15:14 +08:00
lyzno1
f59d430219 Merge remote-tracking branch 'origin/main' into feat/trigger 2025-10-31 20:01:16 +08:00
lyzno1
c415e5b893 Merge branch 'feat/trigger' of https://github.com/langgenius/dify into feat/trigger 2025-10-31 20:00:51 +08:00
lyzno1
67b6b3612c fix: trigger docs link 2025-10-31 20:00:39 +08:00
katakyo
f7f6b4a8b0 i18n(ja-JP): Use 「公開」 for App Overview “Launch” action label (#27680) 2025-10-31 11:23:38 +08:00
Yeuoly
229b0e190f bump version 2025-10-30 23:21:56 +08:00
autofix-ci[bot]
09d412cf2a [autofix.ci] apply automated fixes 2025-10-30 15:18:20 +00:00
Harry
2842cbf1e1 refactor: update error handling to use BadRequest for plugin invocation errors 2025-10-30 23:16:14 +08:00
Harry
e2543bcf30 refactor: remove unused error imports in TriggerManager 2025-10-30 23:05:08 +08:00
Harry
3f75aa6848 refactor: simplify error handling in TriggerManager 2025-10-30 23:04:54 +08:00
Yeuoly
57719f3ce9 fix: docker env 2025-10-30 22:21:14 +08:00
Harry
45677ac57c bump plugin daemon image version to 0.4.0-local 2025-10-30 22:19:24 +08:00
395 changed files with 15167 additions and 2651 deletions

View File

@@ -1,7 +1,6 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv

View File

@@ -117,7 +117,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.

View File

@@ -374,6 +374,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Comma-separated list of file extensions blocked from upload for security reasons.
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
# Empty by default to allow all file types.
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
@@ -521,7 +527,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true
WORKFLOW_LOG_CLEANUP_ENABLED=false
# Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
@@ -620,3 +626,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0

View File

@@ -15,7 +15,11 @@ FROM base AS packages
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
COPY pyproject.toml uv.lock ./
@@ -49,7 +53,9 @@ RUN \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2

View File

@@ -1,7 +1,7 @@
import sys
def is_db_command():
def is_db_command() -> bool:
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False

View File

@@ -1471,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
@click.option(
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
"""
Transform datasource credentials
"""
@@ -1482,9 +1485,14 @@ def transform_datasource_credentials():
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
if environment == "online":
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
jina_plugin_unique_identifier = None
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
@@ -1650,7 +1658,7 @@ def transform_datasource_credentials():
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
provider="jinareader",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,

View File

@@ -360,6 +360,31 @@ class FileUploadConfig(BaseSettings):
default=10,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
"Empty by default to allow all file types."
),
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
default="",
)
@computed_field # type: ignore[misc]
@property
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
"""
Parse and return the blacklist as a set of lowercase extensions.
Returns an empty set if no blacklist is configured.
"""
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
return set()
return {
ext.strip().lower().strip(".")
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
if ext.strip()
}
class HttpConfig(BaseSettings):
"""
@@ -949,6 +974,11 @@ class DataSetConfig(BaseSettings):
default=True,
)
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
description="Maximum number of segments for dataset segments API (0 for unlimited)",
default=0,
)
class WorkspaceConfig(BaseSettings):
"""
@@ -1160,7 +1190,7 @@ class AccountConfig(BaseSettings):
class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
@@ -1179,6 +1209,13 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1205,6 +1242,7 @@ class FeatureConfig(
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,

View File

@@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,

View File

@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class BlockedFileExtensionError(BaseHTTPException):
error_code = "file_extension_blocked"
description = "The file extension is blocked for security reasons."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."

View File

@@ -16,6 +16,7 @@ from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
@@ -175,8 +176,10 @@ class AnnotationApi(Resource):
api.model(
"CreateAnnotationRequest",
{
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
@@ -193,11 +196,14 @@ class AnnotationApi(Resource):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation
@setup_required

View File

@@ -1,7 +1,5 @@
from datetime import datetime
import pytz
import sqlalchemy as sa
from flask import abort
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
@@ -19,7 +17,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
@@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)

View File

@@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
@@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
@@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@cloud_edition_billing_resource_check("annotation")
@account_initialization_required
@edit_permission_required
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count")

View File

@@ -1,9 +1,7 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@@ -11,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
@@ -56,26 +55,16 @@ WHERE
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -120,8 +109,11 @@ class DailyConversationStatistic(Resource):
)
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
stmt = (
sa.select(
@@ -134,18 +126,10 @@ class DailyConversationStatistic(Resource):
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
stmt = stmt.where(Message.created_at < end_datetime_utc)
stmt = stmt.group_by("date").order_by("date")
@@ -198,26 +182,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -273,26 +248,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -357,26 +323,17 @@ FROM
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -446,26 +403,17 @@ WHERE
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -525,26 +473,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -602,26 +541,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc

View File

@@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
@@ -112,7 +113,18 @@ class DraftWorkflowApi(Resource):
},
)
)
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
@@ -979,11 +991,13 @@ class DraftWorkflowTriggerRunApi(Resource):
event = poller.poll()
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=event.workflow_args,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
@@ -992,7 +1006,7 @@ class DraftWorkflowTriggerRunApi(Resource):
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except PluginInvokeError as e:
raise ValueError(e.to_user_friendly_error())
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@@ -1050,7 +1064,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
)
event = poller.poll()
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 500
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@@ -1074,7 +1088,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{"status": "error", "error": "An unexpected error occurred while running the node."}
), 500
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
@@ -1126,7 +1140,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
node_ids=node_ids,
)
except PluginInvokeError as e:
raise ValueError(e.to_user_friendly_error())
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@@ -1134,10 +1148,12 @@ class DraftWorkflowTriggerRunAllApi(Resource):
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
try:
workflow_args = dict(trigger_debug_event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=trigger_debug_event.workflow_args,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=trigger_debug_event.node_id,
@@ -1151,4 +1167,4 @@ class DraftWorkflowTriggerRunAllApi(Resource):
{
"status": "error",
}
), 500
), 400

View File

@@ -1,7 +1,4 @@
from datetime import datetime
import pytz
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
@@ -9,6 +6,7 @@ from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
@@ -43,23 +41,11 @@ class WorkflowDailyRunsStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_daily_runs_statistics(
tenant_id=app_model.tenant_id,
@@ -100,23 +86,11 @@ class WorkflowDailyTerminalsStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
tenant_id=app_model.tenant_id,
@@ -157,23 +131,11 @@ class WorkflowDailyTokenCostStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
tenant_id=app_model.tenant_id,
@@ -214,23 +176,11 @@ class WorkflowAverageAppInteractionStatistic(Resource):
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
start_date = None
end_date = None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_date = start_datetime_timezone.astimezone(utc_timezone)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_date = end_datetime_timezone.astimezone(utc_timezone)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
tenant_id=app_model.tenant_id,

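All four statistics endpoints above now delegate start/end parsing to the parse_time_range helper imported from libs.datetime_utils. Its implementation is not part of this diff; a minimal sketch of what it presumably does, reconstructed from the inline logic it replaces (parse the local "%Y-%m-%d %H:%M" string, drop seconds, localize to the account's timezone, convert to UTC, and reject an inverted range) is shown below — the exact validation and error message are assumptions:

# Hypothetical sketch of libs.datetime_utils.parse_time_range; the real helper may differ.
from datetime import datetime

import pytz


def parse_time_range(start: str | None, end: str | None, tz_name: str):
    """Return (start_utc, end_utc) parsed from local "%Y-%m-%d %H:%M" strings."""
    tz = pytz.timezone(tz_name)

    def to_utc(value: str | None):
        if not value:
            return None
        local = datetime.strptime(value, "%Y-%m-%d %H:%M").replace(second=0)
        return tz.localize(local).astimezone(pytz.utc)

    start_utc, end_utc = to_utc(start), to_utc(end)
    if start_utc and end_utc and start_utc > end_utc:
        raise ValueError("start time must be earlier than end time")
    return start_utc, end_utc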
View File

@@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -16,7 +17,13 @@ class Subscription(Resource):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()

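This hunk, like the wraps changes further down, swaps hard-coded plan strings for members of enums.cloud_plan.CloudPlan. The enum itself is not shown in this diff; assuming it is a StrEnum, its members still compare equal to the old string literals and remain valid reqparse choices:

# Assumed shape of enums/cloud_plan.py; not part of this diff.
from enum import StrEnum


class CloudPlan(StrEnum):
    SANDBOX = "sandbox"
    PROFESSIONAL = "professional"
    TEAM = "team"


# StrEnum members compare equal to their string values, so existing
# comparisons such as plan == "sandbox" keep working unchanged.
assert CloudPlan.SANDBOX == "sandbox"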
View File

@@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@@ -8,6 +8,7 @@ import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
@@ -83,6 +84,8 @@ class FileApi(Resource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
return upload_file, 201

View File

@@ -21,6 +21,7 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
setup_required,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
@@ -83,7 +84,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}

View File

@@ -10,6 +10,7 @@ from flask import abort, request
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
@@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
abort(
403,
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",

View File

@@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@@ -2,6 +2,7 @@ from flask import request
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)

View File

@@ -13,6 +13,7 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -67,6 +68,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
kwargs["app_model"] = app_model
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
@@ -75,7 +77,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
@@ -90,6 +91,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)
@@ -139,7 +162,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
raise Forbidden(
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
)

View File

@@ -37,8 +37,7 @@ def trigger_endpoint(endpoint_id: str):
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e:
logger.exception("Endpoint processing failed for {endpoint_id}: {e}")
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 500
except Exception as e:
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 400
except Exception:
logger.exception("Webhook processing failed for {endpoint_id}")
return jsonify({"error": "Internal server error"}), 500

View File

@@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
query=query,
memory=memory,
)
@@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
query=query,
memory=memory,
)

View File

@@ -79,7 +79,7 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: str | None = None,
query: str = "",
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
@@ -105,7 +105,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query or "",
query=query,
files=files,
context=context,
memory=memory,

View File

@@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query or "",
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,

View File

@@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
# return batch, dataset, documents
return {
"batch": batch,

View File

@@ -39,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@overload
def generate(
self,
@@ -139,8 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
# start node get inputs
# for trigger debug run, not prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,

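The trigger debug endpoints earlier in this diff copy event.workflow_args and set this sentinel key before calling AppGenerateService.generate, so the generator skips _prepare_user_inputs for trigger debug runs. The behavior, restated as a standalone snippet taken directly from the code above:

# Sentinel-key pattern used by the trigger debug run (standalone illustration).
from collections.abc import Mapping
from typing import Any

SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"


def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
    # Absent or falsy flag -> prepare inputs as before; truthy flag -> skip preparation.
    return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))


assert _should_prepare_user_inputs({"inputs": {}}) is True
assert _should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False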
View File

@@ -44,6 +44,9 @@ class InvokeFrom(StrEnum):
DEBUGGER = "debugger"
PUBLISHED = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str):
"""
@@ -110,6 +113,11 @@ class AppGenerateEntity(BaseModel):
inputs: Mapping[str, Any]
files: Sequence[File]
# Unique identifier of the user initiating the execution.
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
#
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
user_id: str
# extras
@@ -135,7 +143,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: str | None = None
query: str = ""
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

View File

@@ -1,15 +1,64 @@
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
entity: WorkflowAppGenerateEntity
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]
class WorkflowResumptionContext(BaseModel):
"""WorkflowResumptionContext captures all state necessary for resumption."""
version: Literal["1"] = "1"
# Only workflow / chatflow could be paused.
generate_entity: _GenerateEntityUnion
serialized_graph_runtime_state: str
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, value: str) -> Self:
return cls.model_validate_json(value)
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
return self.generate_entity.entity
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(self, session_factory: Engine | sessionmaker[Session], state_owner_user_id: str):
def __init__(
self,
session_factory: Engine | sessionmaker[Session],
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating state file for pause.
@@ -19,6 +68,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@@ -49,13 +99,25 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
else:
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=self.graph_runtime_state.dumps(),
state=state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:

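A simplified, self-contained illustration of the discriminated-union pattern WorkflowResumptionContext relies on, with toy models standing in for WorkflowAppGenerateEntity and AdvancedChatAppGenerateEntity — the point is that the type discriminator lets Pydantic rebuild the correct entity class when the persisted pause state is loaded:

# Toy models only; the real wrappers use AppMode as the discriminator value.
from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Field


class WorkflowEntity(BaseModel):
    kind: Literal["workflow"] = "workflow"
    app_id: str


class AdvancedChatEntity(BaseModel):
    kind: Literal["advanced-chat"] = "advanced-chat"
    conversation_id: str


EntityUnion: TypeAlias = Annotated[WorkflowEntity | AdvancedChatEntity, Field(discriminator="kind")]


class ResumptionContext(BaseModel):
    version: Literal["1"] = "1"
    entity: EntityUnion
    serialized_state: str


ctx = ResumptionContext(entity=WorkflowEntity(app_id="app-1"), serialized_state="{}")
restored = ResumptionContext.model_validate_json(ctx.model_dump_json())
assert isinstance(restored.entity, WorkflowEntity)  # discriminator picks the right class back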
View File

@@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

View File

@@ -140,7 +140,27 @@ class MessageCycleManager:
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
# Add new unique resources from the event
for resource in event.retriever_resources or []:
if not resource:
continue
is_duplicate = (
resource.dataset_id
and resource.document_id
and (resource.dataset_id, resource.document_id) in existing_ids
)
if not is_duplicate:
merged_resources.append(resource)
for i, resource in enumerate(merged_resources, 1):
resource.position = i
self._task_state.metadata.retriever_resources = merged_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""

View File

@@ -0,0 +1,15 @@
from collections.abc import Sequence
from dataclasses import dataclass
@dataclass
class DocumentTask:
"""Document task entity for document indexing operations.
This class represents a document indexing task that can be queued
and processed by the document indexing system.
"""
tenant_id: str
dataset_id: str
document_ids: Sequence[str]

View File

@@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
# Return composite sort key: (model_type value, model position index)
return (model.model_type.value, position_index)
# Deduplicate
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
# Sort using the composite sort key
return sorted(provider_models, key=get_sort_key)

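The deduplication one-liner keys each entry by (model, model_type, fetch_from), so a later duplicate replaces an earlier one while first-seen order is preserved; the subsequent sorted() call reorders the result anyway. Shown with plain tuples standing in for the provider model entries:

# Behavior of the dedup idiom above (tuples are illustrative stand-ins).
models = [
    ("gpt-4", "llm", "predefined", 1),
    ("gpt-4", "llm", "predefined", 2),  # duplicate key: replaces the first entry
    ("ada", "embed", "predefined", 3),
]
deduped = list({(m[0], m[1], m[2]): m for m in models}.values())
assert deduped == [("gpt-4", "llm", "predefined", 2), ("ada", "embed", "predefined", 3)]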
View File

@@ -74,6 +74,10 @@ class File(BaseModel):
storage_key: str | None = None,
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
url: str | None = None,
# Legacy compatibility fields - explicitly handle known extra fields
tool_file_id: str | None = None,
upload_file_id: str | None = None,
datasource_file_id: str | None = None,
):
super().__init__(
id=id,

View File

@@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(
f"""
// declare main function
{cls._code_placeholder}
runner_script = dedent(f""" {cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
"""
)
""")
return runner_script

View File

@@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
runner_script = dedent(f""" {cls._code_placeholder}
import json
from base64 import b64decode

View File

@@ -1,21 +1,22 @@
import hashlib
import json
import logging
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@@ -30,9 +31,10 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -99,22 +101,45 @@ def datetime_to_nanos(dt: datetime | None) -> int:
return int(dt.timestamp() * 1_000_000_000)
def string_to_trace_id128(string: str | None) -> int:
"""
Convert any input string into a stable 128-bit integer trace ID.
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
string_stacktrace = "".join(traceback.format_exception(error))
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
else:
error_message = str(error)
return error_message
This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
It's suitable for generating consistent, unique identifiers from strings.
"""
if string is None:
string = ""
hash_object = hashlib.sha256(string.encode())
# Take the first 16 bytes (128 bits) of the hash digest
digest = hash_object.digest()[:16]
def set_span_status(current_span: Span, error: Exception | str | None = None):
"""Set the status of the current span based on the presence of an error."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
# Convert to a 128-bit integer
return int.from_bytes(digest, byteorder="big")
if isinstance(error, Exception):
current_span.record_exception(error)
else:
exception_type = error.__class__.__name__
exception_message = str(error)
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:
current_span.set_status(Status(StatusCode.OK))
def safe_json_dumps(obj: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
return json.dumps(obj, default=str, ensure_ascii=False)
class ArizePhoenixDataTrace(BaseTraceInstance):
@@ -131,9 +156,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.propagator = TraceContextTextMapPropagator()
self.dify_trace_ids: set[str] = set()
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@@ -151,7 +179,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
@@ -166,15 +194,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
workflow_metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
@@ -186,31 +208,58 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
# Through workflow_run_id, get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
try:
# Process workflow nodes
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
for node_execution in workflow_node_executions:
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
inputs_value = node_execution.inputs or {}
outputs_value = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
process_data = node_execution.process_data or {}
execution_metadata = node_execution.metadata or {}
node_metadata = {str(k): v for k, v in execution_metadata.items()}
node_metadata = {
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": node_execution.tenant_id,
"app_id": node_execution.app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
}
if node_execution.execution_metadata:
node_metadata.update(json.loads(node_execution.execution_metadata))
node_metadata.update(
{
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN
@@ -223,8 +272,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model:
node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
@@ -236,17 +286,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
else:
span_kind = OpenInferenceSpanKindValues.CHAIN
workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=workflow_span_context,
)
try:
@@ -260,11 +313,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
outputs = (
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
)
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
@@ -275,8 +325,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
if node_execution.status == "failed":
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
@@ -322,34 +380,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
message_span_id = RandomIdGenerator().generate_span_id()
span_context = SpanContext(
trace_id=trace_id,
span_id=message_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
context=root_span_context,
)
try:
if trace_info.error:
message_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
@@ -383,26 +425,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
message_span_context = set_span_in_context(message_span)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
context=message_span_context,
)
try:
if trace_info.error:
llm_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
if trace_info.message_data.error:
set_span_status(llm_span, trace_info.message_data.error)
else:
set_span_status(llm_span)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
if trace_info.error:
set_span_status(message_span, trace_info.error)
else:
set_span_status(message_span)
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
@@ -418,15 +460,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
@@ -445,19 +481,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
@@ -480,15 +511,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
@@ -499,19 +524,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(end_time))
@@ -533,15 +553,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
@@ -554,19 +568,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(end_time))
@@ -580,20 +589,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
trace_id = string_to_trace_id128(trace_info.message_id)
tool_span_id = RandomIdGenerator().generate_span_id()
logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
# Create span context with the same trace_id as the parent
# todo: Create with the appropriate parent span context, so that the tool span is
# a child of the appropriate span (e.g. message span)
span_context = SpanContext(
trace_id=trace_id,
span_id=tool_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
@@ -612,19 +610,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
context=root_span_context,
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
@@ -641,15 +634,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
@@ -663,22 +650,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
context=root_span_context,
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def ensure_root_span(self, dify_trace_id: str | None):
"""Ensure a unique root span exists for the given Dify trace ID."""
if str(dify_trace_id) not in self.dify_trace_ids:
self.carrier: dict[str, str] = {}
root_span = self.tracer.start_span(name="Dify")
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
root_span.set_attribute("dify_project_name", str(self.project))
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
with use_span(root_span, end_on_exit=False):
self.propagator.inject(carrier=self.carrier)
set_span_status(root_span)
root_span.end()
self.dify_trace_ids.add(str(dify_trace_id))
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
@@ -698,26 +697,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = db.session.scalars(
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
).all()
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}

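The refactor replaces hand-built SpanContext objects with W3C trace-context propagation: ensure_root_span starts one synthetic "Dify" root span per Dify trace ID, injects its context into a carrier dict, and every subsequent span extracts that carrier to parent itself. A minimal sketch of that inject/extract round trip with the OpenTelemetry SDK (exporter setup omitted; names are illustrative):

# Minimal sketch of the inject/extract parenting pattern; assumes the OpenTelemetry SDK.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("sketch")
propagator = TraceContextTextMapPropagator()

carrier: dict[str, str] = {}
root_span = tracer.start_span("Dify")
with use_span(root_span, end_on_exit=False):
    propagator.inject(carrier=carrier)  # serialize the root span's context into the carrier
root_span.end()

# Later spans re-parent themselves by extracting the carrier instead of
# constructing a SpanContext by hand.
parent_ctx = propagator.extract(carrier=carrier)
child = tracer.start_span("workflow", context=parent_ctx)
child.end()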
View File

@@ -5,6 +5,7 @@ Tencent APM Trace Client - handles network operations, metrics, and API communic
from __future__ import annotations
import importlib
import json
import logging
import os
import socket
@@ -110,6 +111,7 @@ class TencentTraceClient:
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.meter_provider: MeterProvider | None = None
self.hist_llm_duration: Histogram | None = None
self.hist_token_usage: Histogram | None = None
self.hist_time_to_first_token: Histogram | None = None
@@ -119,7 +121,6 @@ class TencentTraceClient:
# Metrics exporter and instruments
try:
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
@@ -202,9 +203,11 @@ class TencentTraceClient:
)
if metric_reader is not None:
# Use instance-level MeterProvider instead of global to support config changes
# without worker restart. Each TencentTraceClient manages its own MeterProvider.
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
metrics.set_meter_provider(provider)
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
self.meter_provider = provider
self.meter = provider.get_meter("dify-sdk", dify_config.project.version)
# LLM operation duration histogram
self.hist_llm_duration = self.meter.create_histogram(
@@ -244,6 +247,7 @@ class TencentTraceClient:
self.metric_reader = metric_reader
else:
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
@@ -253,6 +257,7 @@ class TencentTraceClient:
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.meter_provider = None
self.hist_llm_duration = None
self.hist_token_usage = None
self.hist_time_to_first_token = None
@@ -279,6 +284,14 @@ class TencentTraceClient:
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
LLM_OPERATION_DURATION,
latency_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
@@ -317,6 +330,13 @@ class TencentTraceClient:
"server.address": server_address,
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
GEN_AI_TOKEN_USAGE,
token_count,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_token_usage.record(token_count, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)
@@ -344,6 +364,13 @@ class TencentTraceClient:
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
ttft_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_first_token.record(ttft_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)
@@ -371,6 +398,13 @@ class TencentTraceClient:
"stream": "true",
}
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_STREAMING_TIME_TO_GENERATE,
ttg_seconds,
json.dumps(attributes, ensure_ascii=False),
)
self.hist_time_to_generate.record(ttg_seconds, attributes) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)
@@ -390,6 +424,14 @@ class TencentTraceClient:
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
logger.info(
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
GEN_AI_TRACE_DURATION,
duration_seconds,
json.dumps(attrs, ensure_ascii=False),
)
self.hist_trace_duration.record(duration_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)
@@ -474,11 +516,19 @@ class TencentTraceClient:
if self.tracer_provider:
self.tracer_provider.shutdown()
# Shutdown instance-level meter provider
if self.meter_provider is not None:
try:
self.meter_provider.shutdown() # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
pass
logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")

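The metrics changes move from the global meter registry to an instance-level MeterProvider, so reconfiguring Tencent APM or shutting the client down no longer touches global state. A minimal sketch of that pattern, using a console exporter as a stand-in for the real APM exporter (names are illustrative):

# Instance-level MeterProvider sketch; assumes the OpenTelemetry metrics SDK.
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource

reader = PeriodicExportingMetricReader(ConsoleMetricExporter(), export_interval_millis=60_000)
provider = MeterProvider(
    resource=Resource.create({"service.name": "dify-api"}),
    metric_readers=[reader],
)

# Get the meter from this provider, not from the global opentelemetry.metrics registry,
# so reconfiguration or shutdown affects only this client instance.
meter = provider.get_meter("dify-sdk", "1.0.0")
hist = meter.create_histogram("llm.operation.duration", unit="s")
hist.record(0.42, {"model": "gpt-4"})

provider.shutdown()  # flushes the reader; no global state is torn down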
View File

@@ -246,7 +246,7 @@ class RequestFetchAppInfo(BaseModel):
class TriggerInvokeEventResponse(BaseModel):
variables: Mapping[str, Any] = Field(default_factory=dict)
cancelled: bool | None = False
cancelled: bool = Field(default=False)
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)

View File

@@ -147,7 +147,8 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return cast(str, info["version"]["number"])
# remove any suffix like "-SNAPSHOT" from the version string
return cast(str, info["version"]["number"]).split("-")[0]
def _check_version(self):
if parse_version(self._version) < parse_version("8.0.0"):

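The trimmed version string matters because snapshot builds of Elasticsearch report versions such as "8.11.0-SNAPSHOT", which strict version parsing does not handle cleanly. Illustrative values only:

# Trim the pre-release suffix before comparing versions (illustrative).
from packaging.version import Version

raw = "8.11.0-SNAPSHOT"          # hypothetical version reported by a snapshot build
numeric = raw.split("-")[0]      # -> "8.11.0"
print(Version(numeric) >= Version("8.0.0"))  # True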
View File

@@ -39,11 +39,13 @@ class WeaviateConfig(BaseModel):
Attributes:
endpoint: Weaviate server endpoint URL
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
grpc_endpoint: str | None = None
api_key: str | None = None
batch_size: int = 100
@@ -88,9 +90,22 @@ class WeaviateVector(BaseVector):
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
@@ -432,6 +447,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),

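The gRPC endpoint handling leans on two urlparse details noted in the comments above: a scheme must be present for the host and port to parse reliably, and a grpcs scheme implies TLS with 443 as the default port. A quick illustration with hypothetical endpoints:

# urlparse behavior the gRPC endpoint parsing relies on (endpoint values are illustrative).
from urllib.parse import urlparse

endpoint = "weaviate-grpc.internal:50051"            # no scheme supplied
parsed = urlparse(endpoint if "://" in endpoint else f"grpc://{endpoint}")
print(parsed.hostname, parsed.port)                  # weaviate-grpc.internal 50051

secure = urlparse("grpcs://weaviate.example.com")
print(secure.port or 443, secure.scheme == "grpcs")  # 443 True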
View File

@@ -0,0 +1,79 @@
import json
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel, ValidationError
from extensions.ext_redis import redis_client
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
class TaskWrapper(BaseModel):
data: Any
def serialize(self) -> str:
return self.model_dump_json()
@classmethod
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
return cls.model_validate_json(serialized_data)
class TenantIsolatedTaskQueue:
"""
Simple queue for tenant-isolated tasks, used to isolate RAG-related tasks per tenant.
It uses a Redis list to store tasks and a Redis key to store the task-waiting flag.
Supports any task that can be serialized as JSON.
"""
def __init__(self, tenant_id: str, unique_key: str):
self._tenant_id = tenant_id
self._unique_key = unique_key
self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
self._task_key = f"tenant_{unique_key}_task:{tenant_id}"
def get_task_key(self):
return redis_client.get(self._task_key)
def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
redis_client.setex(self._task_key, ttl, 1)
def delete_task_key(self):
redis_client.delete(self._task_key)
def push_tasks(self, tasks: Sequence[Any]):
serialized_tasks = []
for task in tasks:
# Store str list directly, maintaining full compatibility for pipeline scenarios
if isinstance(task, str):
serialized_tasks.append(task)
else:
# Use TaskWrapper to do JSON serialization for non-string tasks
wrapper = TaskWrapper(data=task)
serialized_data = wrapper.serialize()
serialized_tasks.append(serialized_data)
redis_client.lpush(self._queue, *serialized_tasks)
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
if count <= 0:
return []
tasks = []
for _ in range(count):
serialized_task = redis_client.rpop(self._queue)
if not serialized_task:
break
if isinstance(serialized_task, bytes):
serialized_task = serialized_task.decode("utf-8")
try:
wrapper = TaskWrapper.deserialize(serialized_task)
tasks.append(wrapper.data)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
# Fall back to raw string for legacy format or invalid JSON
tasks.append(serialized_task)
return tasks
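A hypothetical usage sketch of the queue above; the tenant id, unique key, and payloads are made up for illustration.

queue = TenantIsolatedTaskQueue(tenant_id="tenant-123", unique_key="document_indexing")

# Non-string tasks are wrapped in TaskWrapper and JSON-serialized;
# plain strings are stored as-is for backward compatibility.
queue.push_tasks([{"dataset_id": "ds-1", "document_ids": ["d-1", "d-2"]}, "legacy-task-id"])
queue.set_task_waiting_time(ttl=600)  # mark work as pending for 10 minutes

tasks = queue.pull_tasks(count=2)  # returns the dict payload and the raw string, in push order
queue.delete_task_key()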

View File

@@ -210,12 +210,13 @@ class Tool(ABC):
meta=meta,
)
def create_json_message(self, object: dict) -> ToolInvokeMessage:
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
type=ToolInvokeMessage.MessageType.JSON,
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
)
def create_variable_message(

View File

@@ -129,6 +129,7 @@ class ToolInvokeMessage(BaseModel):
class JsonMessage(BaseModel):
json_object: dict
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
blob: bytes

View File

@@ -1,16 +1,19 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
logger = logging.getLogger(__name__)
class MCPTool(Tool):
def __init__(
@@ -52,6 +55,11 @@ class MCPTool(Tool):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
else:
logger.warning("Unsupported content type=%s", type(content))
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
@@ -97,6 +105,10 @@ class MCPTool(Tool):
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

View File

@@ -228,29 +228,41 @@ class ToolEngine:
"""
Handle tool response
"""
result = ""
parts: list[str] = []
json_parts: list[str] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += cast(ToolInvokeMessage.TextMessage, response.message).text
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += (
parts.append(
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
parts.append(
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
if json_message.suppress_output:
continue
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
)
else:
result += str(response.message)
parts.append(str(response.message))
return result
# Add JSON parts, avoiding duplicates from text parts.
if json_parts:
existing_parts = set(parts)
parts.extend(p for p in json_parts if p not in existing_parts)
return "".join(parts)
@staticmethod
def _extract_tool_response_binary_and_text(

View File

@@ -117,7 +117,7 @@ class WorkflowTool(Tool):
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
yield self.create_json_message(outputs, suppress_output=True)
@property
def latest_usage(self) -> LLMUsage:

View File

@@ -208,7 +208,7 @@ class SubscriptionBuilder(BaseModel):
endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
credentials: Mapping[str, str] = Field(..., description="The credentials of the subscription builder")
credentials: Mapping[str, Any] = Field(..., description="The credentials of the subscription builder")
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
credential_expires_at: int | None = Field(
default=None, description="The credential expires at of the subscription builder"
@@ -227,7 +227,7 @@ class SubscriptionBuilderUpdater(BaseModel):
name: str | None = Field(default=None, description="The name of the subscription builder")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
credentials: Mapping[str, str] | None = Field(
credentials: Mapping[str, Any] | None = Field(
default=None, description="The credentials of the subscription builder"
)
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")

View File

@@ -13,14 +13,14 @@ import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError, PluginNotFoundError
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
from core.plugin.impl.trigger import PluginTriggerClient
from core.trigger.entities.entities import (
EventEntity,
Subscription,
UnsubscribeResult,
)
from core.trigger.errors import EventIgnoreError, TriggerPluginInvokeError
from core.trigger.errors import EventIgnoreError
from core.trigger.provider import PluginTriggerProviderController
from models.provider_ids import TriggerProviderID
@@ -189,13 +189,10 @@ class TriggerManager:
request=request,
payload=payload,
)
except EventIgnoreError as e:
except EventIgnoreError:
return TriggerInvokeEventResponse(variables={}, cancelled=True)
except PluginInvokeError as e:
logger.exception("Failed to invoke trigger event")
raise TriggerPluginInvokeError(
description=e.to_user_friendly_error(plugin_name=provider.entity.identity.name)
) from e
except Exception as e:
raise e
@classmethod
def subscribe_trigger(

View File

@@ -114,9 +114,45 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)

View File

@@ -16,7 +16,6 @@ from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
@@ -108,8 +107,8 @@ class Worker(threading.Thread):
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
node_id="unknown",
node_type=NodeType.CODE,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),

View File

@@ -153,7 +153,11 @@ class VariablePool(BaseModel):
return None
node_id, name = self._selector_to_keys(selector)
segment: Segment | None = self.variable_dictionary[node_id].get(name)
node_map = self.variable_dictionary.get(node_id)
if node_map is None:
return None
segment: Segment | None = node_map.get(name)
if segment is None:
return None

View File

@@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

0
api/enums/__init__.py Normal file
View File

15
api/enums/cloud_plan.py Normal file
View File

@@ -0,0 +1,15 @@
from enum import StrEnum, auto
class CloudPlan(StrEnum):
"""
Enum representing user plan types in the cloud platform.
SANDBOX: Free/default plan with limited features
PROFESSIONAL: Professional paid plan
TEAM: Team collaboration paid plan
"""
SANDBOX = auto()
PROFESSIONAL = auto()
TEAM = auto()
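Since CloudPlan derives from StrEnum, auto() produces the lower-cased member name, so the values compare and serialize as plain strings:

assert CloudPlan.SANDBOX == "sandbox"
assert CloudPlan.PROFESSIONAL.value == "professional"
assert CloudPlan("team") is CloudPlan.TEAM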

View File

@@ -0,0 +1,134 @@
"""
Broadcast channel for Pub/Sub messaging.
"""
import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self
class Subscription(AbstractContextManager["Subscription"], Protocol):
"""A subscription to a topic that provides an iterator over received messages.
The subscription can be used as a context manager and will automatically
close when exiting the context.
Note: `Subscription` instances are not thread-safe. Each thread should create its own
subscription.
"""
@abstractmethod
def __iter__(self) -> Iterator[bytes]:
"""`__iter__` returns an iterator used to consume the message from this subscription.
If the caller did not enter the context, `__iter__` may lazily perform the setup before
yielding messages; otherwise `__enter__` handles it.”
If the subscription is closed, then the returned iterator exits without
raising any error.
"""
...
@abstractmethod
def close(self) -> None:
"""close closes the subscription, releases any resources associated with it."""
...
def __enter__(self) -> Self:
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
@abstractmethod
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the broadcast channel.
If `timeout` is specified, this method returns `None` if no message is
received within the given period. If `timeout` is `None`, the call blocks
until a message is received.
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
cancel a blocking subscription.
:param timeout: timeout for receive message, in seconds.
Returns:
bytes: The received message as a byte string, or
None: If the timeout expires before a message is received.
Raises:
SubscriptionClosed: If the subscription has already been closed.
"""
...
class Producer(Protocol):
"""Producer is an interface for message publishing. It is already bound to a specific topic.
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def publish(self, payload: bytes) -> None:
"""Publish a message to the bounded topic."""
...
class Subscriber(Protocol):
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def subscribe(self) -> Subscription:
pass
class Topic(Producer, Subscriber, Protocol):
"""A named channel for publishing and subscribing to messages.
Topics provide both read and write access. For restricted access,
use as_producer() for write-only view or as_subscriber() for read-only view.
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def as_producer(self) -> Producer:
"""as_producer creates a write-only view for this topic."""
...
@abstractmethod
def as_subscriber(self) -> Subscriber:
"""as_subscriber create a read-only view for this topic."""
...
class BroadcastChannel(Protocol):
"""A broadcasting channel is a channel supporting broadcasting semantics.
Each channel is identified by a topic, different topics are isolated and do not affect each other.
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
a specific topic, all subscription should receive the published message.
There are no restriction for the persistence of messages. Once a subscription is created, it
should receive all subsequent messages published.
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def topic(self, topic: str) -> "Topic":
"""topic returns a `Topic` instance for the given topic name."""
...
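A hedged usage sketch against these protocols, assuming some concrete BroadcastChannel implementation bound to the name channel (for example the Redis-backed one added in this change):

topic = channel.topic("workflow-events")

producer = topic.as_producer()  # write-only view
with topic.as_subscriber().subscribe() as subscription:
    producer.publish(b'{"event": "node_started"}')
    payload = subscription.receive(timeout=1.0)  # bytes, or None if nothing arrived in time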

View File

@@ -0,0 +1,12 @@
class BroadcastChannelError(Exception):
"""`BroadcastChannelError` is the base class for all exceptions related
to `BroadcastChannel`."""
pass
class SubscriptionClosedError(BroadcastChannelError):
"""SubscriptionClosedError means that the subscription has been closed and
methods for consuming messages should not be called."""
pass

View File

@@ -0,0 +1,3 @@
from .channel import BroadcastChannel
__all__ = ["BroadcastChannel"]

View File

@@ -0,0 +1,200 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class BroadcastChannel:
"""
Redis Pub/Sub based broadcast channel implementation.
Provides "at most once" delivery semantics for messages published to channels.
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
"""
def __init__(
self,
redis_client: Redis,
):
self._client = redis_client
def topic(self, topic: str) -> "Topic":
return Topic(self._client, topic)
class Topic:
def __init__(self, redis_client: Redis, topic: str):
self._client = redis_client
self._topic = topic
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _RedisSubscription(
pubsub=self._client.pubsub(),
topic=self._topic,
)
class _RedisSubscription(Subscription):
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
def _start_if_needed(self) -> None:
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
self._pubsub.subscribe(self._topic)
_logger.debug("Subscribed to channel %s", self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _listen(self) -> None:
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if raw_message is None:
continue
if raw_message.get("type") != "message":
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
continue
self._enqueue_message(payload_bytes)
_logger.debug("Listener thread stopped for channel %s", self._topic)
pubsub.unsubscribe(self._topic)
pubsub.close()
_logger.debug("PubSub closed for topic %s", self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
# method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None
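A hypothetical wiring of the Redis-backed channel above; the Redis URL and topic name are assumptions, and decode_responses must stay False as noted in the class docstring.

from redis import Redis

client = Redis.from_url("redis://localhost:6379/0", decode_responses=False)
channel = BroadcastChannel(client)

topic = channel.topic("trigger:refresh")
with topic.subscribe() as subscription:
    topic.publish(b"refresh-now")
    payload = subscription.receive(timeout=1.0)  # b"refresh-now", or None (at-most-once delivery)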

View File

@@ -2,6 +2,8 @@ import abc
import datetime
from typing import Protocol
import pytz
class _NowFunction(Protocol):
@abc.abstractmethod
@@ -31,3 +33,51 @@ def ensure_naive_utc(dt: datetime.datetime) -> datetime.datetime:
if dt.tzinfo is None:
return dt
return dt.astimezone(datetime.UTC).replace(tzinfo=None)
def parse_time_range(
start: str | None, end: str | None, tzname: str
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
"""
Parse time range strings and convert to UTC datetime objects.
Handles DST ambiguity and non-existent times gracefully.
Args:
start: Start time string (YYYY-MM-DD HH:MM)
end: End time string (YYYY-MM-DD HH:MM)
tzname: Timezone name
Returns:
tuple: (start_datetime_utc, end_datetime_utc)
Raises:
ValueError: When time range is invalid or start > end
"""
tz = pytz.timezone(tzname)
utc = pytz.utc
def _parse(time_str: str | None, label: str) -> datetime.datetime | None:
if not time_str:
return None
try:
dt = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M").replace(second=0)
except ValueError as e:
raise ValueError(f"Invalid {label} time format: {e}")
try:
return tz.localize(dt, is_dst=None).astimezone(utc)
except pytz.AmbiguousTimeError:
return tz.localize(dt, is_dst=False).astimezone(utc)
except pytz.NonExistentTimeError:
dt += datetime.timedelta(hours=1)
return tz.localize(dt, is_dst=None).astimezone(utc)
start_dt = _parse(start, "start")
end_dt = _parse(end, "end")
# Range validation
if start_dt and end_dt and start_dt > end_dt:
raise ValueError("start must be earlier than or equal to end")
return start_dt, end_dt
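A hedged example of the range parsing above; the timezone and timestamps are illustrative (Asia/Shanghai is UTC+8 with no DST, so both results shift back by eight hours):

start_utc, end_utc = parse_time_range("2025-03-01 09:00", "2025-03-01 18:00", "Asia/Shanghai")
assert start_utc is not None and (start_utc.hour, start_utc.minute) == (1, 0)
assert end_utc is not None and end_utc.hour == 10

# Open-ended ranges are allowed; an omitted bound simply comes back as None.
_, none_end = parse_time_range("2025-03-01 09:00", None, "Asia/Shanghai")
assert none_end is None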

View File

@@ -110,7 +110,7 @@ class Account(UserMixin, TypeBase):
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
)
role: TenantAccountRole | None = field(default=None, init=False)
@@ -250,7 +250,9 @@ class Tenant(TypeBase):
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
)
def get_accounts(self) -> list[Account]:
return list(
@@ -289,7 +291,7 @@ class TenantAccountJoin(TypeBase):
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
)
@@ -310,7 +312,7 @@ class AccountIntegrate(TypeBase):
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
)
@@ -396,5 +398,5 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
DateTime, nullable=False, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
)

View File

@@ -61,18 +61,20 @@ class Dataset(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = mapped_column(db.String(255), nullable=True)
embedding_model_provider = mapped_column(db.String(255), nullable=True)
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
embedding_model = mapped_column(sa.String(255), nullable=True)
embedding_model_provider = mapped_column(sa.String(255), nullable=True)
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(JSONB, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(JSONB, nullable=True)
runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'::character varying"))
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(db.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@property
def total_documents(self):
@@ -399,7 +401,9 @@ class Document(Base):
archived_reason = mapped_column(String(255), nullable=True)
archived_by = mapped_column(StringUUID, nullable=True)
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
doc_type = mapped_column(String(40), nullable=True)
doc_metadata = mapped_column(JSONB, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
@@ -716,7 +720,9 @@ class DocumentSegment(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(sa.Text, nullable=True)
@@ -881,7 +887,7 @@ class ChildChunk(Base):
)
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@@ -1036,8 +1042,8 @@ class TidbAuthBinding(Base):
tenant_id = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(String(255), nullable=False, server_default=sa.text("'CREATING'::character varying"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -1088,7 +1094,9 @@ class ExternalKnowledgeApis(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def to_dict(self) -> dict[str, Any]:
return {
@@ -1141,7 +1149,9 @@ class ExternalKnowledgeBindings(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class DatasetAutoDisableLog(Base):
@@ -1197,7 +1207,7 @@ class DatasetMetadata(Base):
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@@ -1224,44 +1234,48 @@ class DatasetMetadataBinding(Base):
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
name = mapped_column(db.String(255), nullable=False)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
name = mapped_column(sa.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False)
chunk_structure = mapped_column(db.String(255), nullable=False)
chunk_structure = mapped_column(sa.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False)
yaml_content = mapped_column(sa.Text, nullable=False)
copyright = mapped_column(db.String(255), nullable=False)
privacy_policy = mapped_column(db.String(255), nullable=False)
copyright = mapped_column(sa.String(255), nullable=False)
privacy_policy = mapped_column(sa.String(255), nullable=False)
position = mapped_column(sa.Integer, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(db.String(255), nullable=False)
language = mapped_column(sa.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name = mapped_column(db.String(255), nullable=False)
name = mapped_column(sa.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False)
chunk_structure = mapped_column(db.String(255), nullable=False)
chunk_structure = mapped_column(sa.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False)
position = mapped_column(sa.Integer, nullable=False)
yaml_content = mapped_column(sa.Text, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(db.String(255), nullable=False)
language = mapped_column(sa.String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def created_user_name(self):
@@ -1273,19 +1287,21 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
class Pipeline(Base): # type: ignore[name-defined]
__tablename__ = "pipelines"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying"))
name = mapped_column(sa.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False, server_default=sa.text("''::character varying"))
workflow_id = mapped_column(StringUUID, nullable=True)
is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
@@ -1294,16 +1310,16 @@ class Pipeline(Base): # type: ignore[name-defined]
class DocumentPipelineExecutionLog(Base):
__tablename__ = "document_pipeline_execution_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
pipeline_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
datasource_type = mapped_column(db.String(255), nullable=False)
datasource_type = mapped_column(sa.String(255), nullable=False)
datasource_info = mapped_column(sa.Text, nullable=False)
datasource_node_id = mapped_column(db.String(255), nullable=False)
datasource_node_id = mapped_column(sa.String(255), nullable=False)
input_data = mapped_column(sa.JSON, nullable=False)
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -1311,12 +1327,14 @@ class DocumentPipelineExecutionLog(Base):
class PipelineRecommendedPlugin(Base):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
plugin_id = mapped_column(sa.Text, nullable=False)
provider_name = mapped_column(sa.Text, nullable=False)
position = mapped_column(sa.Integer, nullable=False, default=0)
active = mapped_column(sa.Boolean, nullable=False, default=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

View File

@@ -72,3 +72,6 @@ class AppTriggerType(StrEnum):
TRIGGER_WEBHOOK = NodeType.TRIGGER_WEBHOOK.value
TRIGGER_SCHEDULE = NodeType.TRIGGER_SCHEDULE.value
TRIGGER_PLUGIN = NodeType.TRIGGER_PLUGIN.value
# for backward compatibility
UNKNOWN = "unknown"

View File

@@ -3,6 +3,7 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
@@ -94,7 +95,9 @@ class App(Base):
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
@@ -313,7 +316,9 @@ class AppModelConfig(Base):
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
opening_statement = mapped_column(sa.Text)
suggested_questions = mapped_column(sa.Text)
suggested_questions_after_answer = mapped_column(sa.Text)
@@ -544,7 +549,9 @@ class RecommendedApp(Base):
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def app(self) -> App | None:
@@ -643,7 +650,9 @@ class Conversation(Base):
read_account_id = mapped_column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
@@ -914,34 +923,42 @@ class Message(Base):
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
model_provider = mapped_column(String(255), nullable=True)
model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
override_model_configs: Mapped[str | None] = mapped_column(sa.Text)
conversation_id: Mapped[str] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False)
message: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
message_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
)
answer: Mapped[str] = mapped_column(sa.Text, nullable=False)
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
parent_message_id = mapped_column(StringUUID, nullable=True)
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price = mapped_column(sa.Numeric(10, 7))
answer_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
)
parent_message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
error = mapped_column(sa.Text)
message_metadata = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'normal'::character varying")
)
error: Mapped[str | None] = mapped_column(sa.Text)
message_metadata: Mapped[str | None] = mapped_column(sa.Text)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
@@ -1212,9 +1229,13 @@ class Message(Base):
@property
def workflow_run(self):
if self.workflow_run_id:
from .workflow import WorkflowRun
from sqlalchemy.orm import sessionmaker
return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first()
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return repo.get_workflow_run_by_id_without_tenant(run_id=self.workflow_run_id)
return None
@@ -1275,20 +1296,22 @@ class MessageFeedback(Base):
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
conversation_id = mapped_column(StringUUID, nullable=False)
message_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
content = mapped_column(sa.Text)
content: Mapped[str | None] = mapped_column(sa.Text)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def from_account(self):
def from_account(self) -> Account | None:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
@@ -1367,7 +1390,9 @@ class MessageAnnotation(Base):
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def account(self):
@@ -1432,7 +1457,9 @@ class AppAnnotationSetting(Base):
created_user_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_user_id = mapped_column(StringUUID, nullable=False)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def collection_binding_detail(self):
@@ -1460,7 +1487,9 @@ class OperationLog(Base):
content = mapped_column(sa.JSON)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class DefaultEndUserSessionID(StrEnum):
@@ -1499,7 +1528,9 @@ class EndUser(Base, UserMixin):
session_id: Mapped[str] = mapped_column()
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class AppMCPServer(Base):
@@ -1519,7 +1550,9 @@ class AppMCPServer(Base):
parameters = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@staticmethod
def generate_server_code(n: int) -> str:
@@ -1565,7 +1598,9 @@ class Site(Base):
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
code = mapped_column(String(255))
@property

View File

@@ -1,62 +1,66 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
__tablename__ = "datasource_oauth_params"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class DatasourceOauthTenantParamConfig(Base):
__tablename__ = "datasource_oauth_tenant_params"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={})
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
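
The datasource models above replace Python-side timestamp defaults (default=datetime.now) with database-side ones. A minimal standalone sketch of that pattern, using an illustrative table name that is not part of this diff:

from datetime import datetime

import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class ExampleBase(DeclarativeBase):
    pass

class ExampleRecord(ExampleBase):
    __tablename__ = "example_records"
    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
    # filled in by the database on INSERT
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp()
    )
    # same server-side default on INSERT, refreshed on every UPDATE via onupdate
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )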

View File

@@ -6,7 +6,7 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .base import Base, TypeBase
from .engine import db
from .types import StringUUID
@@ -41,7 +41,7 @@ class ProviderQuotaType(StrEnum):
raise ValueError(f"No matching enum found for value '{value}'")
class Provider(Base):
class Provider(TypeBase):
"""
Provider model representing the API providers and their configurations.
"""
@@ -55,24 +55,28 @@ class Provider(Base):
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuidv7()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'custom'::character varying")
String(40), nullable=False, server_default=text("'custom'::character varying"), default="custom"
)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
quota_type: Mapped[str | None] = mapped_column(
String(40), nullable=True, server_default=text("''::character varying")
String(40), nullable=True, server_default=text("''::character varying"), default=""
)
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True)
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
def __repr__(self):
return (
@@ -135,7 +139,9 @@ class ProviderModel(Base):
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@cached_property
def credential(self):
@@ -170,7 +176,9 @@ class TenantDefaultModel(Base):
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class TenantPreferredModelProvider(Base):
@@ -185,7 +193,9 @@ class TenantPreferredModelProvider(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ProviderOrder(Base):
@@ -212,7 +222,9 @@ class ProviderOrder(Base):
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ProviderModelSetting(Base):
@@ -234,7 +246,9 @@ class ProviderModelSetting(Base):
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class LoadBalancingModelConfig(Base):
@@ -259,7 +273,9 @@ class LoadBalancingModelConfig(Base):
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ProviderCredential(Base):
@@ -279,7 +295,9 @@ class ProviderCredential(Base):
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
class ProviderModelCredential(Base):
@@ -307,4 +325,6 @@ class ProviderModelCredential(Base):
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
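
The Provider model above is converted to TypeBase, and its columns gain either init=False or an explicit default. A short sketch of why, under the assumption that TypeBase mixes in SQLAlchemy's MappedAsDataclass (so every mapped column becomes a dataclass field that must either accept a constructor argument or opt out):

import sqlalchemy as sa
from sqlalchemy import text
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class ExampleTypeBase(MappedAsDataclass, DeclarativeBase):
    pass

class ExampleProvider(ExampleTypeBase):
    __tablename__ = "example_providers"
    # generated by the database, so excluded from the generated __init__
    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True, init=False)
    provider_name: Mapped[str] = mapped_column(sa.String(255))
    # optional in the constructor (default=False) and on raw SQL inserts (server_default)
    is_valid: Mapped[bool] = mapped_column(sa.Boolean, server_default=text("false"), default=False)

provider = ExampleProvider(provider_name="openai")  # id and is_valid can be omitted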

View File

@@ -140,8 +140,9 @@ class Workflow(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=naive_utc_now(),
server_onupdate=func.current_timestamp(),
default=func.current_timestamp(),
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", sa.Text, nullable=False, server_default="{}"
@@ -150,7 +151,7 @@ class Workflow(Base):
"conversation_variables", sa.Text, nullable=False, server_default="{}"
)
_rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
"rag_pipeline_variables", sa.Text, nullable=False, server_default="{}"
)
VERSION_DRAFT = "draft"
@@ -1106,7 +1107,16 @@ class WorkflowAppLog(Base):
@property
def workflow_run(self):
return db.session.get(WorkflowRun, self.workflow_run_id)
if self.workflow_run_id:
from sqlalchemy.orm import sessionmaker
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return repo.get_workflow_run_by_id_without_tenant(run_id=self.workflow_run_id)
return None
@property
def created_by_account(self):

View File

@@ -37,6 +37,7 @@ dependencies = [
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.8.72",
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.27.0",
"opentelemetry-distro==0.48b0",
"opentelemetry-exporter-otlp==1.27.0",
@@ -211,9 +212,9 @@ vdb = [
"pgvector==0.2.5",
"pymilvus~=2.5.0",
"pymochow==2.2.9",
"pyobvector~=0.2.15",
"pyobvector~=0.2.17",
"qdrant-client==1.9.0",
"tablestore==6.2.0",
"tablestore==6.3.7",
"tcvectordb~=1.6.4",
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",

View File

@@ -7,6 +7,7 @@ from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
@@ -63,7 +64,7 @@ def clean_messages():
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == "sandbox":
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False

View File

@@ -9,6 +9,7 @@ from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
@@ -35,7 +36,7 @@ def clean_unused_datasets_task():
},
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING),
"plan_filter": "sandbox",
"plan_filter": CloudPlan.SANDBOX,
"add_logs": False,
},
]
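
The repeated "sandbox" literals above are replaced with CloudPlan.SANDBOX. A tiny sketch of why the comparisons keep working even when the plan arrives as a plain string (for example, decoded from the Redis cache); the assumption is that CloudPlan is a StrEnum whose SANDBOX member has the value "sandbox", which this diff does not show:

from enum import StrEnum

class CloudPlan(StrEnum):
    SANDBOX = "sandbox"

plan = b"sandbox".decode()              # e.g. a cached plan read back from Redis
assert plan == CloudPlan.SANDBOX        # StrEnum members compare equal to their string value
assert CloudPlan.SANDBOX == "sandbox"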

View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
@@ -45,7 +46,7 @@ def mail_clean_document_notify_task():
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != "sandbox":
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()

View File

@@ -32,41 +32,48 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
answer = args.get("answer") or args.get("content")
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")
if args.get("message_id"):
message_id = str(args["message_id"])
# get message info
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
if not message:
raise NotFound("Message Not Exists.")
question = args.get("question") or message.query or ""
annotation: MessageAnnotation | None = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
annotation.question = args["question"]
annotation.content = answer
annotation.question = question
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args["answer"],
question=args["question"],
content=answer,
question=question,
account_id=current_user.id,
)
else:
annotation = MessageAnnotation(
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
)
question = args.get("question")
if not question:
raise ValueError("'question' is required when 'message_id' is not provided")
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
annotation.question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,

View File

@@ -10,6 +10,7 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from enums.cloud_plan import CloudPlan
from libs.helper import RateLimiter
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
@@ -45,7 +46,7 @@ class AppGenerateService:
if dify_config.BILLING_ENABLED:
# check if it's free plan
limit_info = BillingService.get_info(app_model.tenant_id)
if limit_info["subscription"]["plan"] == "sandbox":
if limit_info["subscription"]["plan"] == CloudPlan.SANDBOX:
if cls.system_rate_limiter.is_rate_limited(app_model.tenant_id):
raise InvokeRateLimitError(
"Rate limit exceeded, please upgrade your plan "

View File

@@ -111,7 +111,9 @@ class AsyncWorkflowService:
app_id=trigger_data.app_id,
workflow_id=workflow.id,
root_node_id=trigger_data.root_node_id,
trigger_metadata=trigger_data.trigger_metadata.model_dump_json(),
trigger_metadata=(
trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}"
),
trigger_type=trigger_data.trigger_type,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps(dict(trigger_data.inputs)),

View File

@@ -4,6 +4,7 @@ from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
@@ -31,7 +32,7 @@ class BillingService:
return {
"limit": knowledge_rate_limit.get("limit", 10),
"subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"),
"subscription_plan": knowledge_rate_limit.get("subscription_plan", CloudPlan.SANDBOX),
}
@classmethod

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
@@ -358,7 +359,7 @@ class ClearFreePlanTenantExpiredLogs:
try:
if (
not dify_config.BILLING_ENABLED
or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox"
or BillingService.get_info(tenant_id)["subscription"]["plan"] == CloudPlan.SANDBOX
):
# only process sandbox tenant
cls.process_tenant(flask_app, tenant_id, days, batch)

View File

@@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
@@ -49,6 +50,7 @@ from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@@ -78,7 +80,6 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
@@ -1042,7 +1043,7 @@ class DatasetService:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
if not features.billing.enabled or features.billing.subscription.plan == CloudPlan.SANDBOX:
return {
"document_ids": [],
"count": 0,
@@ -1438,7 +1439,7 @@ class DocumentService:
count = len(website_info.urls)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1693,7 +1694,7 @@ class DocumentService:
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
@@ -1727,7 +1728,7 @@ class DocumentService:
# count = len(website_info.urls) # type: ignore
# batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
# if features.billing.subscription.plan == "sandbox" and count > 1:
# if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
# raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
# if count > batch_upload_limit:
# raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -2196,7 +2197,7 @@ class DocumentService:
website_info = knowledge_config.data_source.info_list.website_info_list
if website_info:
count = len(website_info.urls)
if features.billing.subscription.plan == "sandbox" and count > 1:
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:

View File

@@ -0,0 +1,83 @@
import logging
from collections.abc import Callable, Sequence
from dataclasses import asdict
from functools import cached_property
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
logger = logging.getLogger(__name__)
class DocumentIndexingTaskProxy:
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
self._tenant_id = tenant_id
self._dataset_id = dataset_id
self._document_ids = document_ids
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
@cached_property
def features(self):
return FeatureService.get_features(self._tenant_id)
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to direct queue", self._dataset_id)
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to tenant queue", self._dataset_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks(
[
asdict(
DocumentTask(
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
)
]
)
logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
def _send_to_default_tenant_queue(self):
self._send_to_tenant_queue(normal_document_indexing_task)
def _send_to_priority_tenant_queue(self):
self._send_to_tenant_queue(priority_document_indexing_task)
def _send_to_priority_direct_queue(self):
self._send_to_direct_queue(priority_document_indexing_task)
def _dispatch(self):
logger.info(
"dispatch args: %s - %s - %s",
self._tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
self._dispatch()
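
DocumentIndexingTaskProxy._dispatch above routes work to one of three destinations. A condensed, standalone restatement of that decision; the queue labels are illustrative, and the real proxy additionally manages the tenant-isolated waiting list:

from enum import StrEnum

class CloudPlan(StrEnum):
    SANDBOX = "sandbox"

def pick_queue(billing_enabled: bool, plan: str) -> str:
    if not billing_enabled:
        # self-hosted / enterprise: priority queue, no tenant isolation
        return "priority_direct"
    if plan == CloudPlan.SANDBOX:
        # sandbox tenants share the normal queue, each behind its own tenant sub-queue
        return "normal_tenant"
    # paid plans get the priority queue, still isolated per tenant
    return "priority_tenant"

assert pick_queue(False, CloudPlan.SANDBOX) == "priority_direct"
assert pick_queue(True, CloudPlan.SANDBOX) == "normal_tenant"
assert pick_queue(True, "team") == "priority_tenant"   # "team" is just an illustrative plan name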

View File

@@ -11,3 +11,7 @@ class FileTooLargeError(BaseServiceError):
class UnsupportedFileTypeError(BaseServiceError):
pass
class BlockedFileExtensionError(BaseServiceError):
description = "File extension '{extension}' is not allowed for security reasons"

View File

@@ -3,12 +3,13 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
class SubscriptionModel(BaseModel):
plan: str = "sandbox"
plan: str = CloudPlan.SANDBOX
interval: str = ""
@@ -186,7 +187,7 @@ class FeatureService:
knowledge_rate_limit.enabled = True
limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
knowledge_rate_limit.limit = limit_info.get("limit", 10)
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX)
return knowledge_rate_limit
@classmethod
@@ -240,7 +241,7 @@ class FeatureService:
features.billing.subscription.interval = billing_info["subscription"]["interval"]
features.education.activated = billing_info["subscription"].get("education", False)
if features.billing.subscription.plan != "sandbox":
if features.billing.subscription.plan != CloudPlan.SANDBOX:
features.webapp_copyright_enabled = True
else:
features.is_allow_transfer_workspace = False

View File

@@ -23,7 +23,7 @@ from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
from .errors.file import FileTooLargeError, UnsupportedFileTypeError
from .errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError
PREVIEW_WORDS_LIMIT = 3000
@@ -59,6 +59,10 @@ class FileService:
if len(filename) > 200:
filename = filename.split(".")[0][:200] + "." + extension
# check if extension is in blacklist
if extension and extension in dify_config.UPLOAD_FILE_EXTENSION_BLACKLIST:
raise BlockedFileExtensionError(f"File extension '.{extension}' is not allowed for security reasons")
if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
raise UnsupportedFileTypeError()
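
A minimal sketch of the blacklist check added above. UPLOAD_FILE_EXTENSION_BLACKLIST comes from the diff, but its exact shape is assumed here to be a set of lowercase extensions without the leading dot:

UPLOAD_FILE_EXTENSION_BLACKLIST = {"exe", "bat", "sh"}   # illustrative contents

class BlockedFileExtensionError(Exception):
    description = "File extension '{extension}' is not allowed for security reasons"

def check_upload_extension(filename: str) -> None:
    extension = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    if extension and extension in UPLOAD_FILE_EXTENSION_BLACKLIST:
        raise BlockedFileExtensionError(f"File extension '.{extension}' is not allowed for security reasons")

check_upload_extension("report.pdf")        # passes silently
try:
    check_upload_extension("payload.exe")   # raises BlockedFileExtensionError
except BlockedFileExtensionError:
    pass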

View File

@@ -0,0 +1,106 @@
import json
import logging
from collections.abc import Callable, Sequence
from functools import cached_property
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from services.feature_service import FeatureService
from services.file_service import FileService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class RagPipelineTaskProxy:
# Default uploaded file name for rag pipeline invoke entities
_RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"
def __init__(
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
):
self._dataset_tenant_id = dataset_tenant_id
self._user_id = user_id
self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")
@cached_property
def features(self):
return FeatureService.get_features(self._dataset_tenant_id)
def _upload_invoke_entities(self) -> str:
text = [item.model_dump() for item in self._rag_pipeline_invoke_entities]
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
)
return upload_file.id
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to direct queue", upload_file_id)
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to tenant queue", upload_file_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
logger.info("push tasks: %s", upload_file_id)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
logger.info("init tasks: %s", upload_file_id)
def _send_to_default_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
def _send_to_priority_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
def _send_to_priority_direct_queue(self, upload_file_id: str):
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
def _dispatch(self):
upload_file_id = self._upload_invoke_entities()
if not upload_file_id:
raise ValueError("upload_file_id is empty")
logger.info(
"dispatch args: %s - %s - %s",
self._dataset_tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different pipeline queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
self._send_to_default_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue with tenant isolation for other plans
self._send_to_priority_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue(upload_file_id)
def delay(self):
if not self._rag_pipeline_invoke_entities:
logger.warning(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
self._dataset_tenant_id,
self._user_id,
)
return
self._dispatch()
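
RagPipelineTaskProxy serializes the invoke entities to an uploaded JSON file and passes only the file id through Celery, keeping the broker message small. A sketch of that serialization step; InvokeEntity is an illustrative stand-in for RagPipelineInvokeEntity, and the commented-out calls mirror the ones used above:

import json

from pydantic import BaseModel

class InvokeEntity(BaseModel):          # stand-in for RagPipelineInvokeEntity
    pipeline_id: str
    inputs: dict

def serialize_entities(entities: list[InvokeEntity]) -> str:
    return json.dumps([entity.model_dump() for entity in entities])

payload = serialize_entities([InvokeEntity(pipeline_id="p1", inputs={"query": "hello"})])
# upload_file = FileService(db.engine).upload_text(payload, "rag_pipeline_invoke_entities.json", user_id, tenant_id)
# task_func.delay(rag_pipeline_invoke_entities_file_id=upload_file.id, tenant_id=tenant_id)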

View File

@@ -126,7 +126,7 @@ workflow:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jina
provider_name: jinareader
provider_type: website_crawl
selected: false
title: Jina Reader

View File

@@ -126,7 +126,7 @@ workflow:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jina
provider_name: jinareader
provider_type: website_crawl
selected: false
title: Jina Reader

View File

@@ -419,7 +419,7 @@ workflow:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jina
provider_name: jinareader
provider_type: website_crawl
selected: false
title: Jina Reader

View File

@@ -99,7 +99,11 @@ class TriggerProviderService:
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(encrypter.mask_credentials(dict(subscription.credentials)))
subscription.credentials = dict(
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
)
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
count = workflows_in_use_map.get(subscription.id)
subscription.workflows_in_use = count if count is not None else 0
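
The subscription listing above now decrypts credentials before masking them; masking the stored ciphertext directly hides the ciphertext rather than the real secret, so the UI never sees the usual partially revealed value. A toy illustration of the difference, with decrypt and mask as stand-ins for the encrypter used in the diff:

def decrypt(values: dict[str, str]) -> dict[str, str]:
    # stand-in: the real encrypter decrypts stored ciphertext
    return {key: value.removeprefix("enc:") for key, value in values.items()}

def mask(values: dict[str, str]) -> dict[str, str]:
    # stand-in: keep the first and last two characters, hide the rest
    return {key: value[:2] + "******" + value[-2:] for key, value in values.items()}

stored = {"api_key": "enc:sk-abcdef123456"}
assert mask(decrypt(stored))["api_key"] == "sk******56"   # masks the real secret
assert mask(stored)["api_key"] == "en******56"            # masking before decrypting exposes only ciphertext fragments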

View File

@@ -22,7 +22,7 @@ class AsyncTriggerStatus(StrEnum):
class TriggerMetadata(BaseModel):
"""Trigger metadata"""
pass
type: AppTriggerType = Field(default=AppTriggerType.UNKNOWN)
class TriggerData(BaseModel):
@@ -36,7 +36,7 @@ class TriggerData(BaseModel):
files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
trigger_type: AppTriggerType
trigger_from: WorkflowRunTriggeredFrom
trigger_metadata: TriggerMetadata = Field(default_factory=TriggerMetadata)
trigger_metadata: TriggerMetadata | None = None
model_config = ConfigDict(use_enum_values=True)
@@ -58,6 +58,8 @@ class ScheduleTriggerData(TriggerData):
class PluginTriggerMetadata(TriggerMetadata):
"""Plugin trigger metadata"""
type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN
endpoint_id: str
plugin_unique_identifier: str
provider_id: str

View File

@@ -8,9 +8,10 @@ from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
from services.workflow.entities import TriggerMetadata
# Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it
@@ -169,14 +170,15 @@ class WorkflowAppService:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:
return {}
icon = metadata.get("icon_filename")
icon_dark = metadata.get("icon_dark_filename")
return {
"icon": PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon) if icon else None,
"icon_dark": PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon_dark)
if icon_dark
else None,
}
trigger_metadata = TriggerMetadata.model_validate(metadata)
if trigger_metadata.type == AppTriggerType.TRIGGER_PLUGIN:
icon = metadata.get("icon_filename")
icon_dark = metadata.get("icon_dark_filename")
metadata["icon"] = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon) if icon else None
metadata["icon_dark"] = (
PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon_dark) if icon_dark else None
)
return metadata
@staticmethod
def _safe_json_loads(val):

View File

@@ -808,7 +808,11 @@ class DraftVariableSaver:
# We only save conversation variable here.
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
segment = WorkflowDraftVariable.build_segment_with_type(segment_type=item.value_type, value=item.new_value)
# Conversation variables are exposed as NUMBER in the UI even if their
# persisted type is INTEGER. Allow float updates by loosening the type
# to NUMBER here so downstream storage infers the precise subtype.
segment_type = SegmentType.NUMBER if item.value_type == SegmentType.INTEGER else item.value_type
segment = WorkflowDraftVariable.build_segment_with_type(segment_type=segment_type, value=item.new_value)
draft_vars.append(
WorkflowDraftVariable.new_conversation_variable(
app_id=self._app_id,
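
The INTEGER-to-NUMBER loosening above is what allows a float to be written into a conversation variable originally declared as an integer. A toy restatement of that decision; SegmentType here is a stand-in for the real enum, whose member values are not shown in this diff:

from enum import StrEnum

class SegmentType(StrEnum):             # illustrative stand-in
    INTEGER = "integer"
    NUMBER = "number"

def loosen_for_conversation_variable(value_type: SegmentType) -> SegmentType:
    # the UI edits conversation variables as NUMBER, so an INTEGER draft variable
    # must accept float updates; other types pass through unchanged
    return SegmentType.NUMBER if value_type == SegmentType.INTEGER else value_type

assert loosen_for_conversation_variable(SegmentType.INTEGER) == SegmentType.NUMBER
assert loosen_for_conversation_variable(SegmentType.NUMBER) == SegmentType.NUMBER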

View File

@@ -10,20 +10,22 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -32,6 +34,7 @@ from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import UserFrom
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
@@ -211,6 +214,9 @@ class WorkflowService:
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
# validate graph structure
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
@@ -267,6 +273,9 @@ class WorkflowService:
if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)
# validate graph structure
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict)
# create new workflow
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
@@ -896,6 +905,43 @@ class WorkflowService:
return new_app
def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]):
"""
Validate workflow graph structure by instantiating the Graph object.
This leverages the built-in graph validators (including trigger/UserInput exclusivity)
and raises any structural errors before persisting the workflow.
"""
node_configs = graph.get("nodes", [])
node_configs = cast(list[dict[str, object]], node_configs)
# empty graph, nothing to validate
if not node_configs:
return
workflow_id = app_model.workflow_id or "UNKNOWN"
Graph.init(
graph_config=graph,
# TODO(Mairuis): Add root node id
root_node_id=None,
node_factory=DifyNodeFactory(
graph_init_params=GraphInitParams(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=workflow_id,
graph_config=graph,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.VALIDATION,
call_depth=0,
),
graph_runtime_state=GraphRuntimeState(
variable_pool=VariablePool(),
start_at=time.perf_counter(),
),
),
)
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(

View File

@@ -48,7 +48,6 @@ def add_document_to_index_task(dataset_document_id: str):
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == False,
DocumentSegment.status == "completed",
)
.order_by(DocumentSegment.position.asc())

View File

@@ -1,11 +1,15 @@
import logging
import time
from collections.abc import Callable, Sequence
import click
from celery import shared_task
from configs import dify_config
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
@@ -21,8 +25,24 @@ def document_indexing_task(dataset_id: str, document_ids: list):
:param dataset_id:
:param document_ids:
.. warning:: TO BE DEPRECATED
This function will be deprecated and removed in a future version.
Use normal_document_indexing_task or priority_document_indexing_task instead.
Usage: document_indexing_task.delay(dataset_id, document_ids)
"""
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
_document_indexing(dataset_id, document_ids)
def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
"""
Process document for tasks
:param dataset_id:
:param document_ids:
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
@@ -38,7 +58,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -86,3 +106,63 @@ def document_indexing_task(dataset_id: str, document_ids: list):
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
def _document_indexing_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
try:
_document_indexing(dataset_id, document_ids)
except Exception:
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
@shared_task(queue="dataset")
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
@shared_task(queue="priority_dataset")
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Priority async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
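
Both indexing tasks above end with the same hand-off inside their finally block: pull the next waiting tasks for the tenant, re-dispatch them while keeping the running flag alive, or clear the flag once the queue is drained. A self-contained sketch of that protocol; FakeQueue stands in for TenantIsolatedTaskQueue and the lambda stands in for the Celery task's delay:

from collections import deque

class FakeQueue:                         # stand-in for TenantIsolatedTaskQueue
    def __init__(self, waiting: list[dict]):
        self._waiting = deque(waiting)
        self.flag_set = True             # a task is currently running
    def pull_tasks(self, count: int) -> list[dict]:
        return [self._waiting.popleft() for _ in range(min(count, len(self._waiting)))]
    def set_task_waiting_time(self) -> None:
        self.flag_set = True             # the real queue refreshes a Redis TTL flag here
    def delete_task_key(self) -> None:
        self.flag_set = False            # drained: producers may dispatch directly again

def finish_and_drain(queue: FakeQueue, delay, concurrency: int) -> None:
    next_tasks = queue.pull_tasks(count=concurrency)
    if next_tasks:
        for task in next_tasks:
            queue.set_task_waiting_time()
            delay(**task)                # hand the next waiting task back to the worker
    else:
        queue.delete_task_key()

queue = FakeQueue([{"tenant_id": "t1", "dataset_id": "d1", "document_ids": ["doc-1"]}])
finish_and_drain(queue, delay=lambda **kwargs: print("re-dispatch", kwargs), concurrency=5)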

View File

@@ -8,6 +8,7 @@ from sqlalchemy import select
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -41,7 +42,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:

View File

@@ -12,8 +12,10 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
@@ -22,6 +24,8 @@ from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
@@ -69,6 +73,27 @@ def priority_rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()

View File

@@ -12,17 +12,20 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
@@ -70,26 +73,27 @@ def rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
redis_client.delete(tenant_pipeline_task_key)
tenant_isolated_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()

View File

@@ -5,8 +5,10 @@ These tasks handle trigger workflow execution asynchronously
to avoid blocking the main request thread.
"""
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
from celery import shared_task
@@ -16,24 +18,27 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
from models.model import EndUser
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription, WorkflowPluginTrigger
from models.workflow import Workflow
from models.trigger import TriggerSubscription, WorkflowPluginTrigger, WorkflowTriggerLog
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from services.workflow.entities import PluginTriggerData, PluginTriggerDispatchData, PluginTriggerMetadata
from services.workflow.queue_dispatcher import QueueDispatcherManager
logger = logging.getLogger(__name__)
@@ -106,6 +111,110 @@ def _get_latest_workflows_by_app_ids(
return {w.app_id: w for w in workflows}
def _record_trigger_failure_log(
*,
session: Session,
workflow: Workflow,
plugin_trigger: WorkflowPluginTrigger,
subscription: TriggerSubscription,
trigger_metadata: PluginTriggerMetadata,
end_user: EndUser | None,
error_message: str,
event_name: str,
request_id: str,
) -> None:
"""
Persist a workflow run, workflow app log, and trigger log entry for failed trigger invocations.
"""
now = datetime.now(UTC)
if end_user:
created_by_role = CreatorUserRole.END_USER
created_by = end_user.id
else:
created_by_role = CreatorUserRole.ACCOUNT
created_by = subscription.user_id
failure_inputs = {
"event_name": event_name,
"subscription_id": subscription.id,
"request_id": request_id,
"plugin_trigger_id": plugin_trigger.id,
}
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=WorkflowRunTriggeredFrom.PLUGIN.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps(failure_inputs),
status=WorkflowExecutionStatus.FAILED.value,
outputs="{}",
error=error_message,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=created_by_role.value,
created_by=created_by,
created_at=now,
finished_at=now,
exceptions_count=0,
)
session.add(workflow_run)
session.flush()
workflow_app_log = WorkflowAppLog(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
created_by_role=created_by_role.value,
created_by=created_by,
)
session.add(workflow_app_log)
dispatcher = QueueDispatcherManager.get_dispatcher(subscription.tenant_id)
queue_name = dispatcher.get_queue_name()
trigger_data = PluginTriggerData(
app_id=plugin_trigger.app_id,
tenant_id=subscription.tenant_id,
workflow_id=workflow.id,
root_node_id=plugin_trigger.node_id,
inputs={},
trigger_metadata=trigger_metadata,
plugin_id=subscription.provider_id,
endpoint_id=subscription.endpoint_id,
)
trigger_log = WorkflowTriggerLog(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
root_node_id=plugin_trigger.node_id,
trigger_metadata=trigger_metadata.model_dump_json(),
trigger_type=AppTriggerType.TRIGGER_PLUGIN,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps({}),
status=WorkflowTriggerStatus.FAILED,
error=error_message,
queue_name=queue_name,
retry_count=0,
created_by_role=created_by_role.value,
created_by=created_by,
triggered_at=now,
finished_at=now,
elapsed_time=0.0,
total_tokens=0,
)
session.add(trigger_log)
session.commit()
def dispatch_triggered_workflow(
user_id: str,
subscription: TriggerSubscription,
@@ -168,23 +277,62 @@ def dispatch_triggered_workflow(
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
# invoke triger
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
user_id=user_id,
provider_id=TriggerProviderID(subscription.provider_id),
# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
parameters=node_data.resolve_parameters(
parameter_schemas=provider_controller.get_event_parameters(event_name=event_name)
),
credentials=subscription.credentials,
credential_type=CredentialType.of(subscription.credential_type),
subscription=subscription.to_entity(),
request=request,
payload=payload,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
if invoke_response.cancelled:
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
user_id=user_id,
provider_id=TriggerProviderID(subscription.provider_id),
event_name=event_name,
parameters=node_data.resolve_parameters(
parameter_schemas=provider_controller.get_event_parameters(event_name=event_name)
),
credentials=subscription.credentials,
credential_type=CredentialType.of(subscription.credential_type),
subscription=subscription.to_entity(),
request=request,
payload=payload,
)
except PluginInvokeError as e:
error_message = e.to_user_friendly_error(plugin_name=trigger_entity.identity.name)
try:
end_user = end_users.get(plugin_trigger.app_id)
_record_trigger_failure_log(
session=session,
workflow=workflow,
plugin_trigger=plugin_trigger,
subscription=subscription,
trigger_metadata=trigger_metadata,
end_user=end_user,
error_message=error_message,
event_name=event_name,
request_id=request_id,
)
except Exception:
logger.exception(
"Failed to record trigger failure log for app %s",
plugin_trigger.app_id,
)
continue
except Exception:
logger.exception(
"Failed to invoke trigger event for app %s",
plugin_trigger.app_id,
)
continue
if invoke_response is not None and invoke_response.cancelled:
logger.info(
"Trigger ignored for app %s with trigger event %s",
plugin_trigger.app_id,
@@ -201,14 +349,7 @@ def dispatch_triggered_workflow(
plugin_id=subscription.provider_id,
endpoint_id=subscription.endpoint_id,
inputs=invoke_response.variables,
trigger_metadata=PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
),
trigger_metadata=trigger_metadata,
)
# Trigger async workflow

Some files were not shown because too many files have changed in this diff.