Compare commits

..

93 Commits

Author SHA1 Message Date
Joel
5abfe7e46f chore: to fix build error again 2025-02-19 13:56:13 +08:00
Joel
2bf4a20875 fix: build web error 2025-02-19 13:40:20 +08:00
Joel
35312cf96c fix: remove run unit test ci to ensure build successfully (#13999) 2025-02-19 12:51:35 +08:00
Joel
15f028fe59 fix: build web image fail (#13996) 2025-02-19 12:35:16 +08:00
Joel
8a2301af56 fix: fix build web image install problem (#13994) 2025-02-19 11:52:36 +08:00
Joel
66747a8eef fix: some GitHub run action problem (#13991) 2025-02-19 11:35:19 +08:00
Wu Tianwei
19d413ac1e feat: date and time picker (#13985) 2025-02-19 10:56:18 +08:00
eux
4a332ff1af fix: update the en-US i18n for logAndAnn (#13902) 2025-02-19 09:15:11 +08:00
taokuizu
dc942db52f chore: remove duplicate import statements (#13959) 2025-02-19 09:14:32 +08:00
jiangbo721
f535a2aa71 chore: prompt_message is actually assistant_message which is a bit am… (#13839)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-02-19 09:14:10 +08:00
Bowen Liang
dfdd6dfa20 fix: change the config name and fix typo in description of the number of retrieval executors (#13856) 2025-02-19 09:13:36 +08:00
github-actions[bot]
2af81d1ee3 chore: translate i18n files (#13795)
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
2025-02-19 09:13:26 +08:00
NFish
ece25bce1a fix: LLM may not return <think> tag, cause thinking time keep increase (#13962) 2025-02-18 21:39:01 +08:00
zxhlyh
6fc234183a chore: replace marketplace search api (#13963) 2025-02-18 21:37:11 +08:00
Yeuoly
15a56f705f fix: get tool provider (#13958) 2025-02-18 20:18:36 +08:00
Yeuoly
899f7e125f Fix/add or update model credentials (#13952) 2025-02-18 20:00:39 +08:00
Jyong
aa19bb3f30 fix session close issue (#13946) 2025-02-18 19:29:57 +08:00
Joel
562852a0ae fix: web test not install pnpm before use (#13931) 2025-02-18 18:18:37 +08:00
Joel
a4b992c1ab fix: web style workflow checkout error (#13929) 2025-02-18 18:18:20 +08:00
zxhlyh
3460c1dfbd fix: tool id (#13932) 2025-02-18 18:17:41 +08:00
Yeuoly
653f6c2d46 fix: fetch configured model providers (#13924) 2025-02-18 17:43:39 +08:00
zxhlyh
ed7851a4b3 fix: plugin tool icon (#13918) 2025-02-18 16:54:14 +08:00
zxhlyh
cb841e5cde fix: plugins install task (#13899) 2025-02-18 15:20:26 +08:00
Yeuoly
4dae0e514e fix: ignore plugin already exists (#13888) 2025-02-18 13:22:39 +08:00
Bowen Liang
363c46ace8 fix: add missing package xinference_client to pass vdb CI tests (#13865) 2025-02-17 23:37:49 +08:00
Charlie.Wei
abe5aca3e2 Retrieval service optimization (#13849) 2025-02-17 18:22:36 +08:00
Yeuoly
bea10b4356 fix: undefined attribute 'query' on MessageAnnotation (#13852) 2025-02-17 18:16:45 +08:00
Joel
f5f83f1924 fix: markdown merge error (#13853) 2025-02-17 18:11:07 +08:00
Yeuoly
403e2d58b9 Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: xhe <xw897002528@gmail.com>
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: takatost <takatost@gmail.com>
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: AkaraChen <akarachen@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com>
Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com>
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Novice <857526207@qq.com>
Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com>
Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com>
Co-authored-by: eux <euxuuu@gmail.com>
Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com>
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
Co-authored-by: lotsik <lotsik@mail.ru>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com>
Co-authored-by: CN-P5 <heibai2006@gmail.com>
Co-authored-by: CN-P5 <heibai2006@qq.com>
Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Boris Feld <lothiraldan@gmail.com>
Co-authored-by: mbo <himabo@gmail.com>
Co-authored-by: mabo <mabo@aeyes.ai>
Co-authored-by: Warren Chen <warren.chen830@gmail.com>
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
Co-authored-by: jiandanfeng <chenjh3@wangsu.com>
Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com>
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com>
Co-authored-by: Xu Song <xusong.vip@gmail.com>
Co-authored-by: rayshaw001 <396301947@163.com>
Co-authored-by: Ding Jiatong <dingjiatong@gmail.com>
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
Co-authored-by: JasonVV <jasonwangiii@outlook.com>
Co-authored-by: le0zh <newlight@qq.com>
Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com>
Co-authored-by: k-zaku <zaku99@outlook.jp>
Co-authored-by: luckylhb90 <luckylhb90@gmail.com>
Co-authored-by: hobo.l <hobo.l@binance.com>
Co-authored-by: jiangbo721 <365065261@qq.com>
Co-authored-by: 刘江波 <jiangbo721@163.com>
Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com>
Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: sino <sino2322@gmail.com>
Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com>
Co-authored-by: lowell <lowell.hu@zkteco.in>
Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com>
Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com>
Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com>
Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com>
Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com>
Co-authored-by: Jason <ggbbddjm@gmail.com>
Co-authored-by: Xin Zhang <sjhpzx@gmail.com>
Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com>
Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com>
Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com>
Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com>
Co-authored-by: Yingchun Lai <laiyingchun@apache.org>
Co-authored-by: Hash Brown <hi@xzd.me>
Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com>
Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com>
Co-authored-by: aplio <ryo.091219@gmail.com>
Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com>
Co-authored-by: Nam Vu <zuzoovn@gmail.com>
Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com>
Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com>
Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com>
Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com>
Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp>
Co-authored-by: HQidea <HQidea@users.noreply.github.com>
Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com>
Co-authored-by: xhe <xw897002528@gmail.com>
Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com>
Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com>
Co-authored-by: engchina <12236799+engchina@users.noreply.github.com>
Co-authored-by: engchina <atjapan2015@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: 呆萌闷油瓶 <253605712@qq.com>
Co-authored-by: Kemal <kemalmeler@outlook.com>
Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com>
Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com>
Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com>
Co-authored-by: steven <sunzwj@digitalchina.com>
Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com>
Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com>
Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com>
Co-authored-by: 胡春东 <gycm520@gmail.com>
Co-authored-by: Junjie.M <118170653@qq.com>
Co-authored-by: MuYu <mr.muzea@gmail.com>
Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com>
Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com>
Co-authored-by: Fei He <droxer.he@gmail.com>
Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com>
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
Co-authored-by: liuzhenghua <1090179900@qq.com>
Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com>
Co-authored-by: AugNSo <song.tiankai@icloud.com>
Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com>
Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com>
Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com>
Co-authored-by: Hundredwz <1808096180@qq.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2025-02-17 17:05:13 +08:00
Charlie.Wei
222df44d21 Retrieval Service efficiency optimization (#13543) 2025-02-17 14:09:57 +08:00
NFish
566e548713 fix: update trial expire time to 3/17 (#13796) 2025-02-17 10:02:21 +08:00
NFish
1434d54e7a Revert "Feat: compliance report download" (#13799) 2025-02-17 09:58:56 +08:00
Xiyuan Chen
4229d0f9a7 Revert "Feat/compliance" (#13798) 2025-02-16 20:58:25 -05:00
NFish
7f9eb35e1f Feat: compliance report download (#13282) 2025-02-17 09:43:41 +08:00
Xiyuan Chen
ed7d7a74ea Feat/compliance (#13548) 2025-02-16 20:31:52 -05:00
kurokobo
035e54ba4d fix: add install a package to improve the accuracy of guessing mime type and file extension (main) (#13752) 2025-02-16 21:39:40 +08:00
Hundredwz
284707c3a8 perf(message): optimize message loading and reduce SQL queries (#13720) 2025-02-15 12:19:01 +08:00
le0zh
1f63028a83 fix: reranking_enable setting failed #13668 (#13721) 2025-02-14 17:42:09 +08:00
Vasu Negi
8a0aa91ed7 Non-Streaming Models Do Not Return Results Properly in _handle_invoke_result (#13571)
Co-authored-by: crazywoola <427733928@qq.com>
2025-02-14 17:02:04 +08:00
呆萌闷油瓶
62079991b7 fix:Knowledge Base with Parent-Child segment mode not support in Agent (#13663) 2025-02-14 14:34:59 +08:00
jiangbo721
4e7e172ff3 Chore/format code (#13691)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-02-14 13:38:17 +08:00
llinvokerl
33a565a719 perf: Implemented short-circuit evaluation for logical conditions (#13674)
Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com>
2025-02-13 19:35:03 +08:00
Novice
f0b9257387 fix: error in obtaining end_to_node_id during conditional parallel execution (#13673) 2025-02-13 18:00:28 +08:00
jiangbo721
c398c9cb6a chore:Remove duplicate code, lines 8 to 27, same as lines 29 & 45 to 62. (#13659)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-02-13 14:51:38 +08:00
Yingchun Lai
a3d3e30e3a fix: fix tongyi models blocking mode with incremental_output=stream (#13620) 2025-02-13 10:24:05 +08:00
AugNSo
2b86465d4c fix document extractor node incorrectly processing doc and ppt files (#12902) 2025-02-12 18:04:28 +08:00
kimjion
6529240da6 fix: no longer using old app detail cover when switch pathname (#13585) 2025-02-12 15:02:11 +08:00
Bowen Liang
0751ad1eeb feat(vdb): add HNSW vector index for TiDB vector store with TiFlash (#12043) 2025-02-12 13:53:51 +08:00
Riddhimaan-Senapati
786550bdc9 fix: changed topics/keywords to topic/keywords (#13544) 2025-02-12 09:15:15 +08:00
jiangbo721
bde756a1ab chore:Remove useless brackets and format code (#13479)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-02-11 22:05:29 +08:00
Wu Jiayang
423fb2d7bc Ensure the 'inputs' field in /chat-messages takes effect every time (#7955)
Co-authored-by: Your Name <you@example.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-02-11 18:44:56 +08:00
Novice
f96b4f287a fix: iteration node log time error (#13511) 2025-02-11 16:35:21 +08:00
Novice
c00e7d3f65 fix: retry log running error (#13472)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2025-02-11 15:48:55 +08:00
yihong
1f38d4846b fix: issue #13483 and #13434 (#13518)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-02-11 12:45:49 +08:00
liuzhenghua
47a64610ca Fix the issue of repeated escaping of quotes in hit test (#13477) 2025-02-11 09:58:31 +08:00
Riddhimaan-Senapati
f0a845f0f9 fix: removed LLM output from the main README (#13504) 2025-02-11 09:09:07 +08:00
Abdullah AlOsaimi
abec23118d feat: add support for X-Forwarded-Port in ProxyFix middleware (#13102) 2025-02-10 22:28:29 +08:00
Xin Zhang
0957119550 fix: update UTC time format for consistency (#13471) 2025-02-10 19:37:50 +08:00
github-actions[bot]
f48fa3e4e8 chore: translate i18n files (#13452)
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
2025-02-10 14:14:15 +08:00
非法操作
5ffc58d6ca feat: improve think content display (#13431) 2025-02-10 14:08:17 +08:00
NFish
7d958635f0 Fix/add trial expire tip time (#13464) 2025-02-10 12:53:59 +08:00
Wu Tianwei
33990426c1 fix: add ids in FetchDatasetsParams (#13459) 2025-02-10 12:28:36 +08:00
yihong
9f3fc7ebf8 ci: make ci safe using zizmor (#13397)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-02-10 12:26:08 +08:00
ybalbert001
c8357da13b [Fix] Sagemaker LLM Provider can't adjust context size, it'a always 2… (#13462)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2025-02-10 12:25:04 +08:00
NFish
2290f14fb1 feat: add tooltip if user's anthropic trial quota still available (#13418) 2025-02-10 10:44:20 +08:00
Fei He
7796984444 Fix: Removed model params except max_token for deepseek r1 in volcengine (#13446) 2025-02-10 10:26:26 +08:00
Fei He
75113c26c6 Feat : add deepseek support for tongyi (#13445) 2025-02-10 10:26:03 +08:00
xhe
939a9ecd21 chore: use the wrap thinking api for volcengine (#13432)
Signed-off-by: xhe <xw897002528@gmail.com>
2025-02-10 10:25:07 +08:00
Summer-Gu
f307c7cd88 feat: Docker adds SSRF-related timeout settings (#13395) 2025-02-10 10:21:31 +08:00
Riddhimaan-Senapati
33ecceb90c Feat: add comparison table to main readme (#13435) 2025-02-10 10:13:46 +08:00
NFish
e0d1cab079 fix: add missed background color to iteration node (#13448) 2025-02-10 10:04:56 +08:00
Riddhimaan-Senapati
811d72a727 feat: added a _position.yaml for vertex ai provider (#13367) 2025-02-09 10:29:07 +08:00
Yi Xiao
c3c575c2e1 Fix: model selector UI hover issue (#13396) 2025-02-09 10:24:57 +08:00
海狸大師
c189629eca Fix(i18n): Refine zh-Hant workflow translations (#13421) 2025-02-09 10:24:45 +08:00
Naoki Takashima
37117c22d4 feat(model): support Gemini 2.0 Flash Lite Preview model (02-05) in Google's model provider (#13399) 2025-02-09 10:22:33 +08:00
Riddhimaan-Senapati
b05e9d2ab4 feat: update backend documentation (#13374) 2025-02-08 20:36:33 +08:00
aplio
0451333990 fix(settings): add notClearable prop to language selection (#13406) 2025-02-08 20:36:23 +08:00
MuYu
ab2e6c19a4 Fixes #13415 reset model-provider-page form value use schema.default (#13416) 2025-02-08 20:34:52 +08:00
aplio
f7959bc887 fix(chatbot): update button class to include text color for better visibility (#13411) 2025-02-08 20:34:37 +08:00
aplio
45874c699d Nitpick/fix typos in document (#13413) 2025-02-08 20:33:45 +08:00
Junjie.M
286cdc41ab reasoning model unified think tag is <think></think> (#13392)
Co-authored-by: crazywoola <427733928@qq.com>
2025-02-08 16:19:41 +08:00
Hash Brown
78708eb5d5 fix: merge conflict between #11301 and #11885 (#13391) 2025-02-08 14:38:10 +08:00
胡春东
cf36745770 fix(workflow_tool): enable File parameter support after workflow is published as a tool (#13175) 2025-02-08 12:30:00 +08:00
depy
6622c7f98d fix: Fix HTTP request node non 443 port SSL site inaccessible (#13376) 2025-02-08 12:00:45 +08:00
Hash Brown
3112b74527 fix: build failed due to getPrevChatList no longer exists (#13383) 2025-02-08 11:59:02 +08:00
Katy Tao
b3ae6b634f feat: add pan and zoom support for MiniMap (#13382) 2025-02-08 11:57:41 +08:00
Xin Zhang
982bca5d40 fix: add rate limiting to prevent brute force on password reset (#13292) 2025-02-08 10:28:31 +08:00
Kalo Chin
c8dcde6cd0 fix: Gemini 2.0 Flash 001 model yaml file naming (#13372) 2025-02-08 09:12:42 +08:00
Riddhimaan-Senapati
8f9db61688 feat: added new silicon flow models (#13369) 2025-02-08 09:12:22 +08:00
github-actions[bot]
ebdbaf34e6 chore: translate i18n files (#13349)
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
2025-02-07 22:41:25 +08:00
zhu-an
a081b1e79e fix: add compatibility config for third-party S3-compatible providers (#13354)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-02-07 22:35:24 +08:00
Steven sun
38c31e64db add enable_search parameter to qwen_max, plus, turbo (#13335)
Co-authored-by: steven <sunzwj@digitalchina.com>
2025-02-07 22:16:26 +08:00
Yi Xiao
ae6f67420c Chore: update app detail panel (#13337) 2025-02-07 18:56:43 +08:00
3507 changed files with 75060 additions and 282157 deletions

View File

@@ -1,11 +1,12 @@
#!/bin/bash
cd web && npm install
npm add -g pnpm@9.12.2
cd web && pnpm install
pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc

View File

@@ -26,6 +26,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }}
uses: ./.github/actions/setup-poetry
@@ -47,15 +50,9 @@ jobs:
- name: Run Unit tests
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -P api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -P api bash dev/pytest/pytest_tools.sh
- name: Run mypy
run: |
poetry run -C api python -m mypy --install-types --non-interactive .

View File

@@ -5,8 +5,8 @@ on:
branches:
- "main"
- "deploy/dev"
tags:
- "*"
release:
types: [published]
concurrency:
group: build-push-${{ github.head_ref || github.run_id }}
@@ -79,10 +79,12 @@ jobs:
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
- name: Export digest
env:
DIGEST: ${{ steps.build.outputs.digest }}
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
sanitized_digest=${DIGEST#sha256:}
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v4
@@ -132,10 +134,15 @@ jobs:
- name: Create manifest list and push
working-directory: /tmp/digests
env:
IMAGE_NAME: ${{ env[matrix.image_name_env] }}
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env[matrix.image_name_env] }}@sha256:%s ' *)
$(printf "$IMAGE_NAME@sha256:%s " *)
- name: Inspect image
env:
IMAGE_NAME: ${{ env[matrix.image_name_env] }}
IMAGE_VERSION: ${{ steps.meta.outputs.version }}
run: |
docker buildx imagetools inspect ${{ env[matrix.image_name_env] }}:${{ steps.meta.outputs.version }}
docker buildx imagetools inspect "$IMAGE_NAME:$IMAGE_VERSION"

View File

@@ -4,6 +4,7 @@ on:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml
@@ -19,6 +20,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Poetry and Python
uses: ./.github/actions/setup-poetry

View File

@@ -9,6 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"

View File

@@ -17,6 +17,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
@@ -59,6 +62,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
@@ -66,21 +72,27 @@ jobs:
with:
files: web/**
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 10
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: yarn
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile
run: pnpm install --frozen-lockfile
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint
run: pnpm run lint
docker-compose-template:
name: Docker Compose Template
@@ -89,6 +101,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
@@ -117,6 +132,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files

View File

@@ -26,16 +26,19 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
cache: ''
cache-dependency-path: 'yarn.lock'
cache-dependency-path: 'pnpm-lock.yaml'
- name: Install Dependencies
run: yarn install
run: pnpm install --frozen-lockfile
- name: Test
run: yarn test
run: pnpm test

View File

@@ -16,6 +16,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
persist-credentials: false
- name: Check for file changes in i18n/en-US
id: check_files
@@ -38,11 +39,11 @@ jobs:
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
run: pnpm install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
run: pnpm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'

View File

@@ -28,6 +28,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }}
uses: ./.github/actions/setup-poetry
@@ -51,7 +54,15 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
- name: Set up Vector Store (TiDB)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: docker/tidb/docker-compose.yaml
services: |
tidb
tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
@@ -67,7 +78,9 @@ jobs:
pgvector
chroma
elasticsearch
tidb
- name: Check TiDB Ready
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: poetry run -P api bash dev/pytest/pytest_vdb.sh

View File

@@ -22,25 +22,34 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: web/**
# to run pnpm, should install package canvas, but it always install failed on amd64 under ubuntu-latest
# - name: Install pnpm
# uses: pnpm/action-setup@v4
# with:
# version: 10
# run_install: false
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: yarn
cache-dependency-path: ./web/package.json
# - name: Setup Node.js
# uses: actions/setup-node@v4
# if: steps.changed-files.outputs.any_changed == 'true'
# with:
# node-version: 20
# cache: pnpm
# cache-dependency-path: ./web/package.json
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile
# - name: Install dependencies
# if: steps.changed-files.outputs.any_changed == 'true'
# run: pnpm install --frozen-lockfile
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn test
# - name: Run tests
# if: steps.changed-files.outputs.any_changed == 'true'
# run: pnpm test

8
.gitignore vendored
View File

@@ -163,6 +163,7 @@ docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
docker/tidb/volumes/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
@@ -175,6 +176,7 @@ docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf
@@ -193,3 +195,9 @@ api/.vscode
.idea/
.vscode
# pnpm
/.pnpm-store
# plugin migrate
plugins.jsonl

View File

@@ -1,3 +0,0 @@
{
"MD024": false
}

View File

@@ -1,32 +0,0 @@
# Changelog
All notable changes to Dify will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.15.7] - 2025-04-27
### Added
- Added support for GPT-4.1 in model providers (#18912)
- Added support for Amazon Bedrock DeepSeek-R1 model (#18908)
- Added support for Amazon Bedrock Claude Sonnet 3.7 model (#18788)
- Refined version compatibility logic in app DSL service
### Fixed
- Fixed issue with creating apps from template categories (#18807, #18868)
- Fixed DSL version check when creating apps from explore templates (#18872, #18878)
## [0.15.6] - 2025-04-22
### Security
- Fixed clickjacking vulnerability (#18552)
- Fixed reset password security issue (#18366)
- Updated reset password token when email code verification succeeds (#18362)
### Fixed
- Fixed Vertex AI Gemini 2.0 Flash 001 schema (#18405)

View File

@@ -108,6 +108,72 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host
**7. Backend-as-a-Service**:
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
## Feature Comparison
<table style="width: 100%;">
<tr>
<th align="center">Feature</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI Assistants API</th>
</tr>
<tr>
<td align="center">Programming Approach</td>
<td align="center">API + App-oriented</td>
<td align="center">Python Code</td>
<td align="center">App-oriented</td>
<td align="center">API-oriented</td>
</tr>
<tr>
<td align="center">Supported LLMs</td>
<td align="center">Rich Variety</td>
<td align="center">Rich Variety</td>
<td align="center">Rich Variety</td>
<td align="center">OpenAI-only</td>
</tr>
<tr>
<td align="center">RAG Engine</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Agent</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Workflow</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Observability</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Enterprise Feature (SSO/Access control)</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Local Deployment</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
</table>
## Using Dify

View File

@@ -87,9 +87,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
## Feature Comparison
<table style="width: 100%;">
<tr
>
<tr>
<th align="center">Feature</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>

View File

@@ -106,6 +106,73 @@ Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-star
**7. Backend-as-a-Service**:
AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko.
## Primerjava Funkcij
<table style="width: 100%;">
<tr>
<th align="center">Funkcija</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI Assistants API</th>
</tr>
<tr>
<td align="center">Programski pristop</td>
<td align="center">API + usmerjeno v aplikacije</td>
<td align="center">Python koda</td>
<td align="center">Usmerjeno v aplikacije</td>
<td align="center">Usmerjeno v API</td>
</tr>
<tr>
<td align="center">Podprti LLM-ji</td>
<td align="center">Bogata izbira</td>
<td align="center">Bogata izbira</td>
<td align="center">Bogata izbira</td>
<td align="center">Samo OpenAI</td>
</tr>
<tr>
<td align="center">RAG pogon</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Agent</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Potek dela</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Spremljanje</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Funkcija za podjetja (SSO/nadzor dostopa)</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Lokalna namestitev</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
</table>
## Uporaba Dify
@@ -187,4 +254,4 @@ Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj
## Licenca
To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.

View File

@@ -1,7 +1,10 @@
.env
*.env.*
storage/generate_files/*
storage/privkeys/*
storage/tools/*
storage/upload_files/*
# Logs
logs
@@ -9,6 +12,8 @@ logs
# jetbrains
.idea
.mypy_cache
.ruff_cache
# venv
.venv

View File

@@ -409,7 +409,6 @@ MAX_VARIABLE_SIZE=204800
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@@ -422,6 +421,22 @@ POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
# Plugin configuration
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
@@ -430,7 +445,4 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400
# Prevent Clickjacking
ALLOW_EMBED=false
LOGIN_LOCKOUT_DURATION=86400

View File

@@ -58,6 +58,8 @@ RUN \
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE
libmagic1 \
&& apt-get autoremove -y \
@@ -71,6 +73,10 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
# Copy source code
COPY . /app/api/

View File

@@ -37,7 +37,13 @@
4. Create environment.
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. You can execute `poetry shell` to activate the environment.
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. First, you need to add the poetry shell plugin, if you don't have it already, in order to run in a virtual environment. [Note: Poetry shell is no longer a native command so you need to install the poetry plugin beforehand]
```bash
poetry self add poetry-plugin-shell
```
Then, You can execute `poetry shell` to activate the environment.
5. Install dependencies

View File

@@ -25,6 +25,8 @@ from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
@click.command("reset-password", help="Reset the account password.")
@@ -524,7 +526,7 @@ def add_qdrant_doc_id_index(field: str):
)
)
except Exception as e:
except Exception:
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
@@ -593,7 +595,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
except Exception:
logging.exception("Failed to execute database migration")
finally:
lock.release()
@@ -639,7 +641,7 @@ where sites.id is null limit 1000"""
account = accounts[0]
print("Fixing missing site for app {}".format(app.id))
app_was_created.send(app, account=account)
except Exception as e:
except Exception:
failed_app_ids.append(app_id)
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
@@ -649,3 +651,69 @@ where sites.id is null limit 1000"""
break
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
def migrate_data_for_plugin():
"""
Migrate data for plugin.
"""
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
PluginDataMigration.migrate()
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
@click.command("extract-plugins", help="Extract plugins.")
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
def extract_plugins(output_file: str, workers: int):
"""
Extract plugins.
"""
click.echo(click.style("Starting extract plugins.", fg="white"))
PluginMigration.extract_plugins(output_file, workers)
click.echo(click.style("Extract plugins completed.", fg="green"))
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
@click.option(
"--output_file",
prompt=True,
help="The file to store the extracted unique identifiers.",
default="unique_identifiers.json",
)
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
def extract_unique_plugins(output_file: str, input_file: str):
"""
Extract unique plugins.
"""
click.echo(click.style("Starting extract unique plugins.", fg="white"))
PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
click.echo(click.style("Extract unique plugins completed.", fg="green"))
@click.command("install-plugins", help="Install plugins.")
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
@click.option(
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
)
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
def install_plugins(input_file: str, output_file: str, workers: int):
"""
Install plugins.
"""
click.echo(click.style("Starting install plugins.", fg="white"))
PluginMigration.install_plugins(input_file, output_file, workers)
click.echo(click.style("Install plugins completed.", fg="green"))

View File

@@ -134,6 +134,60 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class PluginConfig(BaseSettings):
"""
Plugin configs
"""
PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL",
default="http://localhost:5002",
)
PLUGIN_DAEMON_KEY: str = Field(
description="Plugin API key",
default="plugin-api-key",
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
description="Plugin Remote Install Host",
default="localhost",
)
PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
description="Plugin Remote Install Port",
default=5003,
)
PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin packages in bytes",
default=15728640,
)
PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin bundles in bytes",
default=15728640 * 12,
)
class MarketplaceConfig(BaseSettings):
"""
Configuration for marketplace
"""
MARKETPLACE_ENABLED: bool = Field(
description="Enable or disable marketplace",
default=True,
)
MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL",
default="https://marketplace.dify.ai",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@@ -160,6 +214,10 @@ class EndpointConfig(BaseSettings):
default="",
)
ENDPOINT_URL_TEMPLATE: str = Field(
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
)
class FileAccessConfig(BaseSettings):
"""
@@ -315,8 +373,8 @@ class HttpConfig(BaseSettings):
)
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
description="Enable or disable the X-Forwarded-For Proxy Fix middleware from Werkzeug"
" to respect X-* headers to redirect clients",
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
" when the app is behind a single trusted reverse proxy.",
default=False,
)
@@ -498,6 +556,11 @@ class AuthConfig(BaseSettings):
default=86400,
)
FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""
@@ -788,6 +851,8 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
FileAccessConfig,

View File

@@ -1,3 +1,4 @@
import os
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
@@ -166,6 +167,11 @@ class DatabaseConfig(BaseSettings):
default=False,
)
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count(),
)
@computed_field
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return {

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.15.7",
default="1.0.0",
)
COMMIT_SHA: str = Field(

View File

@@ -15,7 +15,7 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@@ -1,9 +1,19 @@
from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")

View File

@@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
@@ -40,6 +40,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
@@ -166,4 +167,15 @@ api.add_resource(
from .tag import tags
# Import workspace controllers
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
)

View File

@@ -2,6 +2,8 @@ from functools import wraps
from flask import request
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = App.query.filter(App.id == args["app_id"]).first()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
@@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
recommended_app = RecommendedApp(
@@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app:
return {"result": "success"}, 204
app = App.query.filter(App.id == recommended_app.app_id).first()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
installed_apps = InstalledApp.query.filter(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
).all()
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
).all()
for installed_app in installed_apps:
db.session.delete(installed_app)

View File

@@ -3,6 +3,8 @@ from typing import Any
import flask_restful # type: ignore
from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
@@ -26,7 +28,16 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="it
def _get_resource(resource_id, tenant_id, resource_model):
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

View File

@@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
@@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200

View File

@@ -2,6 +2,7 @@ from datetime import UTC, datetime
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
@@ -50,33 +51,37 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()
site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
with Session(db.engine) as session:
site = session.query(Site).filter(Site.app_id == app_model.id).first()
for attr_name in [
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
if not site:
raise NotFound
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
for attr_name in [
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.commit()
return site

View File

@@ -20,6 +20,7 @@ from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
@@ -96,6 +97,9 @@ class DraftWorkflowApi(Resource):
else:
abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
try:
@@ -139,6 +143,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="")
@@ -160,7 +167,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
@@ -178,38 +185,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -228,7 +204,44 @@ class WorkflowDraftRunIterationNodeApi(Resource):
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
@@ -246,6 +259,9 @@ class DraftWorkflowRunApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
@@ -294,13 +310,20 @@ class DraftWorkflowNodeRunApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
)
return workflow_node_execution
@@ -339,6 +362,9 @@ class PublishedWorkflowApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
@@ -376,12 +402,17 @@ class DefaultBlockConfigApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
q = args.get("q")
filters = None
if args.get("q"):
if q:
try:
filters = json.loads(args.get("q", ""))
except json.JSONDecodeError:
@@ -407,6 +438,9 @@ class ConvertToWorkflowApi(Resource):
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")

View File

@@ -59,3 +59,9 @@ class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
error_code = "email_code_account_deletion_rate_limit_exceeded"
description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
code = 429
class EmailPasswordResetLimitError(BaseHTTPException):
error_code = "email_password_reset_limit"
description = "Too many failed password reset attempts. Please try again in 24 hours."
code = 429

View File

@@ -3,12 +3,20 @@ import secrets
from flask import request
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
@@ -22,7 +30,6 @@ from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@@ -38,7 +45,8 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
account = Account.query.filter_by(email=args["email"]).first()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
@@ -54,7 +62,6 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@@ -64,6 +71,10 @@ class ForgotPasswordCheckApi(Resource):
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
@@ -72,22 +83,15 @@ class ForgotPasswordCheckApi(Resource):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email")}
class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@@ -106,9 +110,6 @@ class ForgotPasswordResetApi(Resource):
if reset_data is None:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
AccountService.revoke_reset_password_token(token)
@@ -118,7 +119,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt
@@ -139,7 +141,7 @@ class ForgotPasswordResetApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
return {"result": "success"}

View File

@@ -22,7 +22,7 @@ from controllers.console.error import (
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
@@ -38,7 +38,6 @@ class LoginApi(Resource):
"""Resource for user login."""
@setup_required
@email_password_login_enabled
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
@@ -111,7 +110,6 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")

View File

@@ -5,6 +5,8 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
if not account:
account = Account.query.filter_by(email=user_info.email).first()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
return account

View File

@@ -4,6 +4,8 @@ import json
from flask import request
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -76,7 +78,10 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
@@ -108,47 +113,53 @@ class DataSourceNotionListApi(Resource):
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = []
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = Document.query.filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
with Session(db.engine) as session:
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = session.execute(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = session.scalars(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
@@ -158,14 +169,17 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id)
page_id = str(page_id)
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
if not data_source_binding:
raise NotFound("Data source binding not found.")

View File

@@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, enterpris
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -72,7 +73,9 @@ class DatasetListApi(Resource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
# convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality":
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True

View File

@@ -7,7 +7,6 @@ from flask import request
from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import asc, desc
from transformers.hf_argparser import string_to_bool # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -40,6 +39,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -150,8 +150,20 @@ class DatasetDocumentListApi(Resource):
sort = request.args.get("sort", default="-created_at", type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch = string_to_bool(request.args.get("fetch", default="false"))
except (ArgumentTypeError, ValueError, Exception) as e:
fetch_val = request.args.get("fetch", default="false")
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
fetch = True
elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
fetch = False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
f"(case insensitive)."
)
except (ArgumentTypeError, ValueError, Exception):
fetch = False
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@@ -429,6 +441,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except PluginDaemonClientSideError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
@@ -529,6 +543,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except PluginDaemonClientSideError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))

View File

@@ -2,8 +2,11 @@ import os
from flask import session
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
return session.get("is_init_validated") or DifySetup.query.first()
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True

View File

@@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
from . import api
@@ -52,8 +52,9 @@ class SetupApi(Resource):
def get_setup_status():
if dify_config.EDITION == "SELF_HOSTED":
return DifySetup.query.first()
return True
return db.session.query(DifySetup).first()
else:
return True
api.add_resource(SetupApi, "/setup")

View File

@@ -0,0 +1,56 @@
from functools import wraps
from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
user = current_user
tenant_id = user.current_tenant_id
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.filter(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()
)
if not permission:
# no permission set, allow access for everyone
return view(*args, **kwargs)
if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)
return decorated
return interceptor

View File

@@ -0,0 +1,36 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.agent_service import AgentService
class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")

View File

@@ -0,0 +1,205 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
class EndpointCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
settings=settings,
)
}
class EndpointListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()
page = args["page"]
page_size = args["page_size"]
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
)
}
)
class EndpointListForSinglePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
parser.add_argument("plugin_id", type=str, required=True, location="args")
args = parser.parse_args()
page = args["page"]
page_size = args["page_size"]
plugin_id = args["plugin_id"]
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
)
}
)
class EndpointDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
if not user.is_admin_or_owner:
raise Forbidden()
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
class EndpointUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
settings=settings,
)
}
class EndpointEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
class EndpointDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")

View File

@@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource):
response = {"result": "success" if result else "error"}
if not result:
response["error"] = error
response["error"] = error or "Unknown error"
return response
@@ -125,9 +125,10 @@ class ModelProviderIconApi(Resource):
Get model provider icon
"""
def get(self, provider: str, icon_type: str, lang: str):
def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
tenant_id=tenant_id,
provider=provider,
icon_type=icon_type,
lang=lang,
@@ -183,53 +184,17 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
return data
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
return result
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
args = parser.parse_args()
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
)
return result
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
)
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
)
api.add_resource(
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
)
api.add_resource(
ModelProviderFreeQuotaQualificationVerifyApi,
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource):
response = {"result": "success" if result else "error"}
if not result:
response["error"] = error
response["error"] = error or ""
return response
@@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource):
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<string:provider>/models/enable",
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<string:provider>/models/disable",
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@@ -0,0 +1,475 @@
import io
from flask import request, send_file
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.manager.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
try:
plugins = PluginService.list(tenant_id)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
class PluginIconApi(Resource):
@setup_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("tenant_id", type=str, required=True, location="args")
req.add_argument("filename", type=str, required=True, location="args")
args = req.parse_args()
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
file = request.files["pkg"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
file = request.files["bundle"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try:
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()
try:
response = PluginService.install_from_github(
tenant_id,
args["plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
args["original_plugin_unique_identifier"],
args["new_plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("install_permission", type=str, required=True, location="json")
req.add_argument("debug_permission", type=str, required=True, location="json")
args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
return jsonable_encoder(
{
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
)
return jsonable_encoder(
{
"install_permission": permission.install_permission,
"debug_permission": permission.debug_permission,
}
)
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

View File

@@ -25,8 +25,10 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
@@ -47,28 +49,43 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
tenant_id = user.current_tenant_id
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
)
)
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
@@ -82,11 +99,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -131,11 +150,13 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -168,6 +189,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
@@ -175,8 +201,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
args = parser.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema(
current_user.id,
current_user.current_tenant_id,
user_id,
tenant_id,
args["url"],
)
@@ -186,8 +212,10 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -209,11 +237,13 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -248,11 +278,13 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -272,8 +304,10 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -293,7 +327,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
class ToolApiProviderSchemaApi(Resource):
@@ -344,11 +382,13 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -381,11 +421,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -421,11 +463,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -444,8 +488,10 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@@ -476,8 +522,10 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@@ -498,8 +546,10 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -517,8 +567,10 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -536,8 +588,10 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -563,16 +617,18 @@ class ToolLabelsApi(Resource):
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")

View File

@@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
@@ -134,9 +135,13 @@ def setup_required(view):
@wraps(view)
def decorated(*args, **kwargs):
# check setup
if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first():
if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first():
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
raise NotSetupError()
return view(*args, **kwargs)
@@ -154,16 +159,3 @@ def enterprise_license_required(view):
return view(*args, **kwargs)
return decorated
def email_password_login_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)
# otherwise, return 403
abort(403)
return decorated

View File

@@ -6,4 +6,4 @@ bp = Blueprint("files", __name__)
api = ExternalApi(bp)
from . import image_preview, tool_files
from . import image_preview, tool_files, upload

View File

@@ -0,0 +1,69 @@
from flask import request
from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from fields.file_fields import file_fields
from services.file_service import FileService
class PluginUploadFileApi(Resource):
@setup_required
@marshal_with(file_fields)
def post(self):
# get file from request
file = request.files["file"]
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
tenant_id = request.args.get("tenant_id")
if not tenant_id:
raise Forbidden("Invalid request.")
user_id = request.args.get("user_id")
user = get_user(tenant_id, user_id)
filename = file.filename
mimetype = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
if not timestamp or not nonce or not sign:
raise Forbidden("Invalid request.")
if not verify_plugin_file_signature(
filename=filename,
mimetype=mimetype,
tenant_id=tenant_id,
user_id=user_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
):
raise Forbidden("Invalid request.")
try:
upload_file = FileService.upload_file(
filename=filename,
content=file.read(),
mimetype=mimetype,
user=user,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

View File

@@ -5,4 +5,5 @@ from libs.external_api import ExternalApi
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
api = ExternalApi(bp)
from .plugin import plugin
from .workspace import workspace

View File

@@ -0,0 +1,293 @@
from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.file.helpers import get_signed_file_url_for_plugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeParameterExtractorNode,
RequestInvokeQuestionClassifierNode,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeSummary,
RequestInvokeTextEmbedding,
RequestInvokeTool,
RequestInvokeTTS,
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from models.account import Account, Tenant
from models.model import EndUser
class PluginInvokeLLMApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLM)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
def generator():
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
class PluginInvokeTextEmbeddingApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_text_embedding(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeRerankApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeRerank)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_rerank(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeTTSApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTTS)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
def generator():
response = PluginModelBackwardsInvocation.invoke_tts(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
class PluginInvokeSpeech2TextApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSpeech2Text)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_speech2text(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeModerationApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeModeration)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_moderation(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeToolApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTool)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
def generator():
return PluginToolBackwardsInvocation.convert_to_event_stream(
PluginToolBackwardsInvocation.invoke_tool(
tenant_id=tenant_model.id,
user_id=user_model.id,
tool_type=ToolProviderType.value_of(payload.tool_type),
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
),
)
return compact_generate_response(generator())
class PluginInvokeParameterExtractorNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
tenant_id=tenant_model.id,
user_id=user_model.id,
parameters=payload.parameters,
model_config=payload.model,
instruction=payload.instruction,
query=payload.query,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeQuestionClassifierNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginNodeBackwardsInvocation.invoke_question_classifier(
tenant_id=tenant_model.id,
user_id=user_model.id,
query=payload.query,
model_config=payload.model,
classes=payload.classes,
instruction=payload.instruction,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeAppApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeApp)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
response = PluginAppBackwardsInvocation.invoke_app(
app_id=payload.app_id,
user_id=user_model.id,
tenant_id=tenant_model.id,
conversation_id=payload.conversation_id,
query=payload.query,
stream=payload.response_mode == "streaming",
inputs=payload.inputs,
files=payload.files,
)
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
class PluginInvokeEncryptApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeEncrypt)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
"""
encrypt or decrypt data
"""
try:
return BaseBackwardsInvocationResponse(
data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginInvokeSummaryApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSummary)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
try:
return BaseBackwardsInvocationResponse(
data={
"summary": PluginModelBackwardsInvocation.invoke_summary(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
}
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginUploadFileRequestApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestRequestUploadFile)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
# generate signed url
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")

View File

@@ -0,0 +1,116 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from flask import request
from flask_restful import reqparse # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from services.account_service import AccountService
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
try:
with Session(db.engine) as session:
if not user_id:
user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=True if user_id == "DEFAULT-USER" else False,
session_id=user_id,
)
session.add(user_model)
session.commit()
else:
user_model = AccountService.load_user(user_id)
if not user_model:
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception:
raise ValueError("user not found")
return user_model
def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
# fetch json body
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json")
kwargs = parser.parse_args()
user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
if not tenant_id:
raise ValueError("tenant_id is required")
if not user_id:
user_id = "DEFAULT-USER"
del kwargs["tenant_id"]
del kwargs["user_id"]
try:
tenant_model = (
db.session.query(Tenant)
.filter(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["user_model"] = get_user(tenant_id, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func):
def decorated_view(*args, **kwargs):
try:
data = request.get_json()
except Exception:
raise ValueError("invalid json")
try:
payload = payload_type(**data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")
kwargs["payload"] = payload
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created
from models.account import Account
from services.account_service import TenantService
@@ -12,7 +12,7 @@ from services.account_service import TenantService
class EnterpriseWorkspace(Resource):
@setup_required
@inner_api_only
@enterprise_inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
@@ -33,7 +33,7 @@ class EnterpriseWorkspace(Resource):
class EnterpriseWorkspaceNoOwnerEmail(Resource):
@setup_required
@inner_api_only
@enterprise_inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
@@ -50,8 +50,8 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
"plan": tenant.plan,
"status": tenant.status,
"custom_config": json.loads(tenant.custom_config) if tenant.custom_config else {},
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
"updated_at": tenant.updated_at.isoformat() if tenant.updated_at else None,
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
}
return {

View File

@@ -10,7 +10,7 @@ from extensions.ext_database import db
from models.model import EndUser
def inner_api_only(view):
def enterprise_inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
@@ -18,7 +18,7 @@ def inner_api_only(view):
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
abort(401)
return view(*args, **kwargs)
@@ -26,7 +26,7 @@ def inner_api_only(view):
return decorated
def inner_api_user_auth(view):
def enterprise_inner_api_user_auth(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
@@ -60,3 +60,19 @@ def inner_api_user_auth(view):
return view(*args, **kwargs)
return decorated
def plugin_inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
abort(404)
return view(*args, **kwargs)
return decorated

View File

@@ -10,6 +10,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.message_fields import feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
@@ -18,26 +19,6 @@ from services.message_service import MessageService
class MessageListApi(Resource):
feedback_fields = {"rating": fields.String}
retriever_resource_fields = {
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
@@ -89,7 +70,7 @@ class MessageListApi(Resource):
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@@ -21,7 +21,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from libs import helper
from libs.helper import TimestampField, uuid_value
@@ -34,27 +34,6 @@ from services.message_service import MessageService
class MessageListApi(WebApiResource):
feedback_fields = {"rating": fields.String}
retriever_resource_fields = {
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
message_fields = {
"id": fields.String,
"conversation_id": fields.String,

View File

@@ -1,7 +1,6 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
@@ -32,19 +31,16 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -62,11 +58,9 @@ class BaseAgentRunner(AppRunner):
queue_manager: AppQueueManager,
message: Message,
user_id: str,
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance,
) -> None:
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@@ -79,8 +73,6 @@ class BaseAgentRunner(AppRunner):
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
@@ -141,11 +133,10 @@ class BaseAgentRunner(AppRunner):
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from,
)
tool_entity.load_variables(self.variables_pool)
assert tool_entity.entity.description
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.description.llm if tool_entity.description else "",
description=tool_entity.entity.description.llm,
parameters={
"type": "object",
"properties": {},
@@ -153,7 +144,7 @@ class BaseAgentRunner(AppRunner):
},
)
parameters = tool_entity.get_all_runtime_parameters()
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
@@ -186,9 +177,11 @@ class BaseAgentRunner(AppRunner):
"""
convert dataset retriever tool to prompt message tool
"""
assert tool.entity.description
prompt_tool = PromptMessageTool(
name=tool.identity.name if tool.identity else "unknown",
description=tool.description.llm if tool.description else "",
name=tool.entity.identity.name,
description=tool.entity.description.llm,
parameters={
"type": "object",
"properties": {},
@@ -234,8 +227,7 @@ class BaseAgentRunner(AppRunner):
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
if dataset_tool.identity is not None:
tool_instances[dataset_tool.identity.name] = dataset_tool
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
@@ -320,24 +312,24 @@ class BaseAgentRunner(AppRunner):
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
tool_name: str | None,
tool_input: Union[str, dict, None],
thought: str | None,
observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None],
answer: str,
answer: str | None,
messages_ids: list[str],
llm_usage: LLMUsage | None = None,
):
"""
Save agent thought
"""
queried_thought = (
updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
)
if not queried_thought:
raise ValueError(f"Agent thought {agent_thought.id} not found")
agent_thought = queried_thought
if not updated_agent_thought:
raise ValueError("agent thought not found")
agent_thought = updated_agent_thought
if thought:
agent_thought.thought = thought
@@ -349,39 +341,39 @@ class BaseAgentRunner(AppRunner):
if isinstance(tool_input, dict):
try:
tool_input = json.dumps(tool_input, ensure_ascii=False)
except Exception as e:
except Exception:
tool_input = json.dumps(tool_input)
agent_thought.tool_input = tool_input
updated_agent_thought.tool_input = tool_input
if observation:
if isinstance(observation, dict):
try:
observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
except Exception:
observation = json.dumps(observation)
agent_thought.observation = observation
updated_agent_thought.observation = observation
if answer:
agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
updated_agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
updated_agent_thought.message_token = llm_usage.prompt_tokens
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
updated_agent_thought.answer_token = llm_usage.completion_tokens
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
updated_agent_thought.tokens = llm_usage.total_tokens
updated_agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
labels = updated_agent_thought.tool_labels or {}
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
for tool in tools:
if not tool:
continue
@@ -392,42 +384,20 @@ class BaseAgentRunner(AppRunner):
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}
agent_thought.tool_labels_str = json.dumps(labels)
updated_agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
try:
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
except Exception as e:
except Exception:
tool_invoke_meta = json.dumps(tool_invoke_meta)
agent_thought.tool_meta_str = tool_invoke_meta
updated_agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit()
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
if not queried_variables:
return
db_variables = queried_variables
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize agent history
@@ -464,11 +434,11 @@ class BaseAgentRunner(AppRunner):
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
for tool in tools:
@@ -515,7 +485,11 @@ class BaseAgentRunner(AppRunner):
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if not files:
return UserPromptMessage(content=message.query)
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if message.app_model_config:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
else:
file_extra_config = None
if not file_extra_config:
return UserPromptMessage(content=message.query)

View File

@@ -1,6 +1,6 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from core.agent.base_agent_runner import BaseAgentRunner
@@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from models.model import Message
@@ -27,11 +27,11 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] | None = None
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
_instruction: str = "" # FIXME this must be str for now
_query: str | None = None
_prompt_messages_tools: list[PromptMessageTool] = []
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
def run(
self,
@@ -42,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
@@ -54,17 +55,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
@@ -104,6 +107,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@@ -115,14 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
if not isinstance(chunks, Generator):
raise ValueError("Expected streaming response from LLM")
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
@@ -142,25 +139,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
if scratchpad.agent_response is not None:
scratchpad.agent_response += json.dumps(chunk.model_dump())
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
if scratchpad.agent_response is not None:
scratchpad.agent_response += chunk
if scratchpad.thought is not None:
scratchpad.thought += chunk
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
if scratchpad.thought is not None:
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
if self._agent_scratchpad is not None:
self._agent_scratchpad.append(scratchpad)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
@@ -255,8 +252,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
answer=final_answer,
messages_ids=[],
)
if self.variables_pool is not None and self.db_variables_pool is not None:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -274,7 +269,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
@@ -314,11 +309,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
)
# publish files
for message_file_id, save_as in message_files:
if save_as is not None and self.variables_pool:
# FIXME the save_as type is confusing, it should be a string or not
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@@ -341,7 +332,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e:
except Exception:
continue
return instruction
@@ -378,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(
self, current_session_messages: Optional[list[PromptMessage]] = None
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
@@ -390,8 +381,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
if not isinstance(message.content, str | None):
raise NotImplementedError("expected str type")
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
@@ -410,9 +400,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
except:
pass
elif isinstance(message, ToolPromptMessage):
if not current_scratchpad:
continue
if isinstance(message.content, str):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")

View File

@@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
"""
Organize system prompt
"""
if not self.app_config.agent:
raise ValueError("Agent configuration is not set")
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
@@ -83,8 +83,10 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"

View File

@@ -1,18 +1,21 @@
from enum import Enum
from typing import Any, Literal, Optional, Union
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api", "workflow"]
provider_type: ToolProviderType
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
plugin_unique_identifier: str | None = None
class AgentPromptEntity(BaseModel):
@@ -66,7 +69,7 @@ class AgentEntity(BaseModel):
Agent Entity.
"""
class Strategy(Enum):
class Strategy(StrEnum):
"""
Agent Strategy.
"""
@@ -78,5 +81,13 @@ class AgentEntity(BaseModel):
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] | None = None
tools: Optional[list[AgentToolEntity]] = None
max_iteration: int = 5
class AgentInvokeMessage(ToolInvokeMessage):
"""
Agent Invoke Message.
"""
pass

View File

@@ -46,18 +46,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@@ -84,6 +86,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@@ -106,7 +109,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None
if self.stream_tool_call and isinstance(chunks, Generator):
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@@ -123,7 +126,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@@ -139,7 +142,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = chunk.delta.usage
yield chunk
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
else:
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
@@ -150,7 +153,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@@ -182,8 +185,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
usage=result.usage,
),
)
else:
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
@@ -242,15 +243,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id, save_as in message_files:
if save_as:
if self.variables_pool:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as
)
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@@ -302,8 +300,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
if self.variables_pool and self.db_variables_pool:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -334,9 +330,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract tool calls from llm result chunk
@@ -359,7 +353,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract blocking tool calls from llm result
@@ -382,9 +376,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def _init_system_message(
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""

View File

@@ -0,0 +1,89 @@
import enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.plugin.entities.parameters import (
PluginParameter,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolProviderIdentity,
)
class AgentStrategyProviderIdentity(ToolProviderIdentity):
"""
Inherits from ToolProviderIdentity, without any additional fields.
"""
pass
class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
"""
Keep all the types from PluginParameterType
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class AgentStrategyProviderEntity(BaseModel):
identity: AgentStrategyProviderIdentity
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
class AgentStrategyIdentity(ToolIdentity):
"""
Inherits from ToolIdentity, without any additional fields.
"""
pass
class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
return v or []
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
strategies: list[AgentStrategyEntity] = Field(default_factory=list)

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
class BaseAgentStrategy(ABC):
"""
Agent Strategy
"""
def invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
Get the parameters for the agent strategy.
"""
return []
@abstractmethod
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

View File

@@ -0,0 +1,59 @@
from collections.abc import Generator, Sequence
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
class PluginAgentStrategy(BaseAgentStrategy):
"""
Agent Strategy
"""
tenant_id: str
declaration: AgentStrategyEntity
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
self.tenant_id = tenant_id
self.declaration = declaration
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
"""
Initialize the parameters for the agent strategy.
"""
for parameter in self.declaration.parameters:
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
manager = PluginAgentManager()
initialized_params = self.initialize_parameters(params)
params = convert_parameters_to_plugin_format(initialized_params)
yield from manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
agent_provider=self.declaration.identity.provider,
agent_strategy=self.declaration.identity.name,
agent_params=params,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)

View File

@@ -4,7 +4,8 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager
@@ -63,14 +64,14 @@ class ModelConfigConverter:
stop = completion_params["stop"]
del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
model_mode = LLMMode.CHAT.value
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@@ -2,8 +2,9 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
@@ -53,9 +54,18 @@ class ModelConfigManager:
raise ValueError("model must be of object type")
# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name

View File

@@ -37,17 +37,6 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
@@ -65,20 +54,31 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@@ -140,9 +140,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
@@ -156,6 +154,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
workflow=workflow,
@@ -167,8 +167,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
) -> Mapping[str, Any] | Generator[str, None, None]:
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@@ -205,6 +211,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
workflow=workflow,
@@ -224,7 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.

View File

@@ -56,7 +56,7 @@ def _process_future(
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
@@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher:
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
self.voices = self.model_instance.get_tts_voices(language=language)
values = [voice.get("value") for voice in self.voices]
self.voice = voice
if not voice or voice not in values:

View File

@@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
else:
inputs = self.application_generate_entity.inputs

View File

@@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from typing import Any, cast
@@ -58,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
) -> Generator[dict | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -84,12 +83,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
) -> Generator[dict | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -123,4 +122,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk

View File

@@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
@@ -219,7 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline:
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@@ -247,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
except Exception:
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break
if tts_publisher:
@@ -640,6 +643,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session.commit()
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

View File

@@ -1,3 +1,4 @@
import contextvars
import logging
import threading
import uuid
@@ -29,17 +30,6 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
@@ -51,6 +41,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
@@ -60,7 +61,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
def generate(
self,
@@ -70,7 +71,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.
@@ -148,9 +149,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
@@ -182,6 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": contextvars.copy_context(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -206,6 +206,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(
self,
flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
@@ -220,6 +221,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID
:return:
"""
for var, val in context.items():
var.set(val)
with flask_app.app_context():
try:
# get conversation and message

View File

@@ -8,18 +8,16 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
@@ -55,6 +53,20 @@ class AgentChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
@@ -72,8 +84,8 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
inputs=dict(inputs),
files=list(files),
query=query,
memory=memory,
)
@@ -85,8 +97,8 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id,
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
inputs=dict(inputs),
query=query or "",
message_id=message.id,
)
except ModerationError as e:
@@ -142,9 +154,9 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
inputs=dict(inputs),
files=list(files),
query=query or "",
memory=memory,
)
@@ -159,16 +171,7 @@ class AgentChatAppRunner(AppRunner):
return
agent_entity = app_config.agent
if not agent_entity:
raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
assert agent_entity is not None
# init model instance
model_instance = ModelInstance(
@@ -179,9 +182,9 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
inputs=dict(inputs),
files=list(files),
query=query or "",
memory=memory,
)
@@ -229,8 +232,6 @@ class AgentChatAppRunner(AppRunner):
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance,
)
@@ -247,73 +248,3 @@ class AgentChatAppRunner(AppRunner):
stream=application_generate_entity.stream,
agent=True,
)
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables:
# save tool variables to session, so that we can update it later
db.session.add(tool_variables)
else:
# create new tool variables
tool_variables = ToolConversationVariables(
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
)

View File

@@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@@ -51,10 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -80,13 +79,12 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk
@classmethod
def convert_stream_simple_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -118,4 +116,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk

View File

@@ -14,21 +14,15 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls,
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
invoke_from: InvokeFrom,
) -> Mapping[str, Any] | Generator[str, None, None]:
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f"data: {chunk}\n\n"
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
return _generate_full_response()
else:
@@ -36,12 +30,8 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f"data: {chunk}\n\n"
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
return _generate_simple_response()
@@ -59,14 +49,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod

View File

@@ -1,5 +1,6 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import VariableEntityType
from core.file import File, FileUploadConfig
@@ -138,3 +139,21 @@ class BaseAppGenerator:
if isinstance(value, str):
return value.replace("\x00", "")
return value
@classmethod
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
"""
Convert messages into event stream
"""
if isinstance(generator, dict):
return generator
else:
def gen():
for message in generator:
if isinstance(message, (Mapping, dict)):
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"
return gen()

View File

@@ -2,7 +2,7 @@ import queue
import time
from abc import abstractmethod
from enum import Enum
from typing import Any
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@@ -115,7 +115,7 @@ class AppQueueManager:
Set task stop flag
:return:
"""
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return

View File

@@ -15,8 +15,10 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
@@ -29,6 +31,106 @@ if TYPE_CHECKING:
class AppRunner:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:return:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
# get prompt messages without memory and context
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query,
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
return rest_tokens
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(
self,
app_record: App,

View File

@@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
@@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
@@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Generate App response.
@@ -141,9 +141,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,

View File

@@ -50,6 +50,20 @@ class ChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
@@ -180,6 +194,9 @@ class ChatAppRunner(AppRunner):
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,

View File

@@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@@ -52,9 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -80,13 +79,12 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -118,4 +116,4 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk

View File

@@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload
def generate(
@@ -56,8 +56,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate(
self,
@@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.

View File

@@ -43,6 +43,20 @@ class CompletionAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages(
@@ -138,6 +152,9 @@ class CompletionAppRunner(AppRunner):
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,

View File

@@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
@@ -51,9 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -78,13 +77,12 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -115,4 +113,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk

View File

@@ -36,13 +36,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
@@ -50,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...
@overload
@@ -64,26 +64,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
@@ -124,7 +124,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
app_model=app_model,
@@ -146,7 +149,18 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@@ -185,10 +199,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
@@ -224,6 +238,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate(
app_model=app_model,

View File

@@ -1,9 +1,9 @@
import json
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@@ -36,9 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -62,13 +61,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -94,4 +92,4 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk

View File

@@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
@@ -190,7 +191,9 @@ class WorkflowAppGenerateTaskPipeline:
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@@ -527,6 +530,10 @@ class WorkflowAppGenerateTaskPipeline:
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

View File

@@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -27,6 +28,7 @@ from core.app.entities.queue_entities import (
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
@@ -239,6 +241,7 @@ class WorkflowBasedAppRunner(AppRunner):
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@@ -373,6 +376,19 @@ class WorkflowBasedAppRunner(AppRunner):
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(

View File

@@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping
single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@@ -6,7 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -41,6 +41,7 @@ class QueueEvent(StrEnum):
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
STOP = "stop"
@@ -280,6 +281,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
class QueueNodeSucceededEvent(AppQueueEvent):
@@ -315,6 +317,22 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueAgentLogEvent(AppQueueEvent):
"""
QueueAgentLogEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_LOG
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
class QueueNodeRetryEvent(QueueNodeStartedEvent):
"""QueueNodeRetryEvent entity"""

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from models.workflow import WorkflowNodeExecutionStatus
@@ -60,6 +61,7 @@ class StreamEvent(Enum):
ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
class StreamResponse(BaseModel):
@@ -247,6 +249,7 @@ class NodeStartStreamResponse(StreamResponse):
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@@ -696,3 +699,26 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
event: StreamEvent = StreamEvent.AGENT_LOG
data: Data

View File

@@ -24,6 +24,8 @@ class HostingModerationFeature:
if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(model_config, text)
moderation_result = moderation.check_moderation(
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
)
return moderation_result

View File

@@ -215,7 +215,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher, task_id)

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -24,6 +25,7 @@ from core.app.entities.queue_entities import (
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@@ -320,9 +322,8 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
execution_metadata_dict = dict(event.execution_metadata or {})
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
@@ -540,6 +541,7 @@ class WorkflowCycleManage:
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)
@@ -842,4 +844,25 @@ class WorkflowCycleManage:
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution
return session.merge(cached_workflow_node_execution)
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
),
)

View File

@@ -1,4 +1,4 @@
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping
from typing import Any, Optional, TextIO, Union
from pydantic import BaseModel
@@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel):
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage] | str,
tool_outputs: Iterable[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,

View File

@@ -1,5 +1,26 @@
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from collections.abc import Generator, Iterable, Mapping
from typing import Any, Optional
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""
def on_tool_execution(
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Iterable[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[ToolInvokeMessage, None, None]:
for tool_output in tool_outputs:
print_text("\n[on_tool_execution]\n", color=self.color)
print_text("Tool: " + tool_name + "\n", color=self.color)
print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color)
print_text("\n")
yield tool_output

Some files were not shown because too many files have changed in this diff Show More