mirror of
https://github.com/langgenius/dify.git
synced 2026-04-09 05:31:24 +08:00
Compare commits
715 Commits
feat/trigg
...
feat/ga-tr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8adb882c4b | ||
|
|
56abf7313a | ||
|
|
c7a63ed01a | ||
|
|
2b8be2c869 | ||
|
|
897458fbe1 | ||
|
|
f31519067d | ||
|
|
01cfe2c197 | ||
|
|
fa7a0fd3d1 | ||
|
|
2853ee4a68 | ||
|
|
2a20a96123 | ||
|
|
0ae2d2b631 | ||
|
|
da2421f378 | ||
|
|
02b11c2fe3 | ||
|
|
ddd643961f | ||
|
|
eef4223063 | ||
|
|
5824ab766c | ||
|
|
613b6cb035 | ||
|
|
197ba5f6ef | ||
|
|
692cd37402 | ||
|
|
7b424d6270 | ||
|
|
89447930b6 | ||
|
|
461078a0cb | ||
|
|
cb9148a39a | ||
|
|
cdd85ff736 | ||
|
|
07a4d7a182 | ||
|
|
56dd17cdaa | ||
|
|
92ab453a5a | ||
|
|
2cd322af21 | ||
|
|
97f0dff9e1 | ||
|
|
1205960c59 | ||
|
|
8254a40e75 | ||
|
|
ef043c6906 | ||
|
|
7e21e8ab01 | ||
|
|
a5ec05acbc | ||
|
|
83f2e14249 | ||
|
|
d8fe0c916b | ||
|
|
69e3ccaef0 | ||
|
|
2c2b3092f6 | ||
|
|
4a797ab2d8 | ||
|
|
63b74c17df | ||
|
|
1058961fd8 | ||
|
|
b340234711 | ||
|
|
79200d89d8 | ||
|
|
a93cbc0461 | ||
|
|
4736819dd9 | ||
|
|
b5e23b84ce | ||
|
|
5a4afcf8fd | ||
|
|
b407cf0189 | ||
|
|
74c0afe5f5 | ||
|
|
8b7fe6add7 | ||
|
|
ab814e3eac | ||
|
|
a0e1eeb3f1 | ||
|
|
b1ebeb67a7 | ||
|
|
082179f70f | ||
|
|
b827bcdd47 | ||
|
|
0b021273bc | ||
|
|
4590f7daa5 | ||
|
|
7d19ca6a03 | ||
|
|
74511c6fdc | ||
|
|
3bcadbe673 | ||
|
|
8786ebdbca | ||
|
|
b49a4eab62 | ||
|
|
873de82d01 | ||
|
|
0a7b59f500 | ||
|
|
c264d9152f | ||
|
|
3bf9d898c0 | ||
|
|
9eb46c297b | ||
|
|
24fbbbb07b | ||
|
|
7551d14256 | ||
|
|
a7f2849e74 | ||
|
|
48db4ab271 | ||
|
|
0957ece92f | ||
|
|
3d08c79c3e | ||
|
|
949bf38d3c | ||
|
|
7bafb7f959 | ||
|
|
2e0a7857f0 | ||
|
|
9735f55ca4 | ||
|
|
4c1f9b949b | ||
|
|
0af0c94dde | ||
|
|
8e4f0640cc | ||
|
|
1f513e3b43 | ||
|
|
aa0841e2a8 | ||
|
|
b6a1562357 | ||
|
|
bee0797401 | ||
|
|
e085f39c13 | ||
|
|
cf6742e9f2 | ||
|
|
b48a7c7cda | ||
|
|
a799493bf7 | ||
|
|
dc21ab5f59 | ||
|
|
e7a575a33c | ||
|
|
ffd3a461f6 | ||
|
|
344844d3e0 | ||
|
|
6e9f82491d | ||
|
|
372b1c3db8 | ||
|
|
58d305dbed | ||
|
|
0360a0416b | ||
|
|
72282b6e8f | ||
|
|
8391884c4e | ||
|
|
b018f2b0a0 | ||
|
|
f3b63e5126 | ||
|
|
b5f25c85a5 | ||
|
|
ab56b4a818 | ||
|
|
6feeef96c8 | ||
|
|
ffc551f003 | ||
|
|
2b64ff1af2 | ||
|
|
1a351d2832 | ||
|
|
5d839497fc | ||
|
|
572abae37a | ||
|
|
cdf3cffb30 | ||
|
|
daf474dd3a | ||
|
|
d380b2f431 | ||
|
|
01daef21e2 | ||
|
|
e8cb4a5aa7 | ||
|
|
e554790a03 | ||
|
|
7222b1f36f | ||
|
|
a29f3107e9 | ||
|
|
43ccac4197 | ||
|
|
784633a0dc | ||
|
|
3f1e4a201e | ||
|
|
0326786cad | ||
|
|
5fb6b51998 | ||
|
|
753234fdfe | ||
|
|
a23bf53d2b | ||
|
|
6087a80455 | ||
|
|
21f7ccf67d | ||
|
|
b3316e3755 | ||
|
|
d9ef302259 | ||
|
|
7434460b5c | ||
|
|
698a94cc3e | ||
|
|
1f4c541c0d | ||
|
|
faefc3492e | ||
|
|
50bbd5fd65 | ||
|
|
c79977ab57 | ||
|
|
4905c26379 | ||
|
|
28c5d3898f | ||
|
|
40e3d4dc7d | ||
|
|
8cf4a0d3ad | ||
|
|
c22ba3f537 | ||
|
|
c200bbb9fc | ||
|
|
66bca831cc | ||
|
|
a2b1a148d3 | ||
|
|
d7905fb5fe | ||
|
|
1705c07967 | ||
|
|
69eb38060d | ||
|
|
ab6e00a6c1 | ||
|
|
a85b7dac33 | ||
|
|
3d3e09d54d | ||
|
|
61ebc756aa | ||
|
|
af68599ce5 | ||
|
|
977188505e | ||
|
|
8aa4db0c77 | ||
|
|
65a3646ce7 | ||
|
|
4bea38042a | ||
|
|
337abc536b | ||
|
|
cc02b78aca | ||
|
|
18f2d24f8e | ||
|
|
0c7b9a462f | ||
|
|
c7b82f2236 | ||
|
|
37c28c5edb | ||
|
|
6af9aeb345 | ||
|
|
4dd5580854 | ||
|
|
440bd825d8 | ||
|
|
d2379c38bd | ||
|
|
cbc55c577b | ||
|
|
946f0b00e4 | ||
|
|
8e962d15d1 | ||
|
|
b1a8db77db | ||
|
|
7193333b75 | ||
|
|
8a0f14fde4 | ||
|
|
106f0fba0b | ||
|
|
cb73335599 | ||
|
|
f4fa57dac9 | ||
|
|
b07c766551 | ||
|
|
9e3dd69277 | ||
|
|
db9e5665c2 | ||
|
|
df1859f6b8 | ||
|
|
b6b1140a21 | ||
|
|
cad77ce0bf | ||
|
|
6f4518ebf7 | ||
|
|
a8f5748dee | ||
|
|
738d3001be | ||
|
|
a12350f0a0 | ||
|
|
8555dd8154 | ||
|
|
06f364f2c8 | ||
|
|
979d35d87f | ||
|
|
7ca06931ec | ||
|
|
df4e32aaa0 | ||
|
|
a25e37a96d | ||
|
|
f156b46705 | ||
|
|
52b708aaad | ||
|
|
7275485bc1 | ||
|
|
6405228f3f | ||
|
|
3b64e118d0 | ||
|
|
566cd20849 | ||
|
|
fa1b497e0f | ||
|
|
23d073f16c | ||
|
|
40af17fdac | ||
|
|
f4d62baaca | ||
|
|
584921081b | ||
|
|
8a68c2877b | ||
|
|
df76527f29 | ||
|
|
c12b1f4abd | ||
|
|
53a80a5dbe | ||
|
|
ca46718a88 | ||
|
|
d14413f3b0 | ||
|
|
4cb42499ae | ||
|
|
3c6035490d | ||
|
|
4a9fe55976 | ||
|
|
7d91f4783b | ||
|
|
5c6a2af448 | ||
|
|
4fd968270c | ||
|
|
1507792a0c | ||
|
|
00b9bbff75 | ||
|
|
e1f8b4b387 | ||
|
|
708a7dd362 | ||
|
|
1539d86f7d | ||
|
|
cd85b75312 | ||
|
|
d685da377e | ||
|
|
8583992d23 | ||
|
|
7877706b5f | ||
|
|
67bb14d3ee | ||
|
|
5653309080 | ||
|
|
0f52b34b61 | ||
|
|
75e35857c1 | ||
|
|
7a54607219 | ||
|
|
e3f3cac10f | ||
|
|
8090342f60 | ||
|
|
0e60b30d23 | ||
|
|
651814b78e | ||
|
|
8ee6ae5a25 | ||
|
|
54fc8f4390 | ||
|
|
4e5405dce2 | ||
|
|
e54131e95b | ||
|
|
23fec75c90 | ||
|
|
d98f375926 | ||
|
|
ebe7303894 | ||
|
|
79fb977f10 | ||
|
|
d5a7a537e5 | ||
|
|
c0af3414a3 | ||
|
|
4f81be70e3 | ||
|
|
0a6da0bf2f | ||
|
|
1d4d627d05 | ||
|
|
341ee17042 | ||
|
|
7ac1629797 | ||
|
|
ed97dce06e | ||
|
|
2357234f39 | ||
|
|
9c6d059227 | ||
|
|
064075ab5f | ||
|
|
1857d37fae | ||
|
|
60fdbb56a9 | ||
|
|
a3f7d8f996 | ||
|
|
4c7853164d | ||
|
|
56f12e70c1 | ||
|
|
1f0cbfbdc4 | ||
|
|
b14afda160 | ||
|
|
1a699cb52d | ||
|
|
b211147ff0 | ||
|
|
2096aa61e8 | ||
|
|
44b4948972 | ||
|
|
7a6bf12453 | ||
|
|
b95f4822c3 | ||
|
|
ae2bd7d116 | ||
|
|
6d2673ca3f | ||
|
|
d1f34cc44c | ||
|
|
4a05b9ab90 | ||
|
|
d5bb050567 | ||
|
|
ce728bcff4 | ||
|
|
572619017a | ||
|
|
0a6cc130b2 | ||
|
|
af0e109f00 | ||
|
|
0d73133982 | ||
|
|
ec064ccdd8 | ||
|
|
cc5291f6af | ||
|
|
7ac496af7b | ||
|
|
04e69be990 | ||
|
|
6c7a3ce4bb | ||
|
|
b4d412fead | ||
|
|
a9e74b21f1 | ||
|
|
150250454e | ||
|
|
a538f80e95 | ||
|
|
e6730f7164 | ||
|
|
3344723393 | ||
|
|
c571185a91 | ||
|
|
325c1cfa41 | ||
|
|
1069421753 | ||
|
|
b33a97ea5b | ||
|
|
d2c1d4c337 | ||
|
|
67762cf1d8 | ||
|
|
eadce0287c | ||
|
|
487eac3b91 | ||
|
|
84b2913cd9 | ||
|
|
7fc822379d | ||
|
|
176d810c8d | ||
|
|
b6d9360f72 | ||
|
|
9fc2a0a3a1 | ||
|
|
ecaff5b63f | ||
|
|
44dcad0d24 | ||
|
|
a300c9ef96 | ||
|
|
44fe71e4db | ||
|
|
0ac32188c5 | ||
|
|
9aaace706b | ||
|
|
b22de5a824 | ||
|
|
97463661c1 | ||
|
|
239a11855a | ||
|
|
0632557d91 | ||
|
|
44be7d4c51 | ||
|
|
efb4a9d327 | ||
|
|
a077a3f609 | ||
|
|
0d8bc70601 | ||
|
|
7eab5b794d | ||
|
|
2bab4e7164 | ||
|
|
3ccec0aab0 | ||
|
|
ca0a149f06 | ||
|
|
db83b54a88 | ||
|
|
f4567fbf9e | ||
|
|
8fd088754a | ||
|
|
05751ac492 | ||
|
|
e7d63a9fa3 | ||
|
|
3006133f0e | ||
|
|
a1e3a72274 | ||
|
|
c5566d707a | ||
|
|
79beb25530 | ||
|
|
b47b228164 | ||
|
|
be91db14d9 | ||
|
|
9e66564526 | ||
|
|
120893209e | ||
|
|
1a4600ce77 | ||
|
|
781a9a56cd | ||
|
|
27e8a076ff | ||
|
|
f19630bcf5 | ||
|
|
9d93fda471 | ||
|
|
d986659add | ||
|
|
00dab7ca5f | ||
|
|
a4add403fb | ||
|
|
e9cdc96c74 | ||
|
|
6af1fea232 | ||
|
|
45d5d9e44f | ||
|
|
376a084aca | ||
|
|
d1f42d47fe | ||
|
|
64b8fd87ad | ||
|
|
364be48248 | ||
|
|
2bce046278 | ||
|
|
1120d552b6 | ||
|
|
4055fcd891 | ||
|
|
857d82109d | ||
|
|
bf20f3aa8b | ||
|
|
048a2c7979 | ||
|
|
d0b99170f6 | ||
|
|
1d16528dff | ||
|
|
c0ed353c10 | ||
|
|
93be1219eb | ||
|
|
3276d6429d | ||
|
|
b786fdf484 | ||
|
|
50072a63ae | ||
|
|
69cab0817f | ||
|
|
1ab7e1cba8 | ||
|
|
26b63693aa | ||
|
|
5883dc8876 | ||
|
|
3e14077e8d | ||
|
|
3a7f729176 | ||
|
|
fa11f0e488 | ||
|
|
8673234f12 | ||
|
|
ac3bd5161c | ||
|
|
8c55f99740 | ||
|
|
6e0aa7766c | ||
|
|
c4d03bf378 | ||
|
|
61d9428064 | ||
|
|
f6038a4557 | ||
|
|
b0aef35c63 | ||
|
|
c90e564d99 | ||
|
|
6c039be2ca | ||
|
|
ac351b700c | ||
|
|
d1e5d30ea9 | ||
|
|
c73e84d992 | ||
|
|
832dabc8a4 | ||
|
|
c09d205776 | ||
|
|
468ec438ab | ||
|
|
eba703083e | ||
|
|
560b026523 | ||
|
|
b4f9698289 | ||
|
|
c9a2dc0b13 | ||
|
|
b6114266af | ||
|
|
33b576b9d5 | ||
|
|
1da2028d9d | ||
|
|
63de2cc7a0 | ||
|
|
c9c3aa0810 | ||
|
|
841b7fa7ce | ||
|
|
dee4399060 | ||
|
|
7c3f6dcc8d | ||
|
|
1472884eb5 | ||
|
|
d3ea98037e | ||
|
|
2c408445ff | ||
|
|
8d2b5c5464 | ||
|
|
97b5d4bba1 | ||
|
|
68c7e43d8c | ||
|
|
e0af930acd | ||
|
|
dfc8bc4aec | ||
|
|
63b4bca7d8 | ||
|
|
517f8aafdc | ||
|
|
d05ba90779 | ||
|
|
b6620c1f42 | ||
|
|
b60e1f4222 | ||
|
|
4df606f439 | ||
|
|
2a5a497f15 | ||
|
|
3f1da39aee | ||
|
|
740f970041 | ||
|
|
ec22b1c706 | ||
|
|
a1712df7c2 | ||
|
|
a40e11cb3e | ||
|
|
61c46bea40 | ||
|
|
1c5c28a82c | ||
|
|
2310145937 | ||
|
|
6a9c9cadd0 | ||
|
|
a52edc6cc1 | ||
|
|
7774ff9944 | ||
|
|
c367f80ec5 | ||
|
|
09998612e7 | ||
|
|
f71ad55d58 | ||
|
|
5b81397054 | ||
|
|
da353a42da | ||
|
|
e056e0835a | ||
|
|
e1819fb7e5 | ||
|
|
9b0f172f91 | ||
|
|
85dfc013ea | ||
|
|
5a1fae1171 | ||
|
|
33d4c95470 | ||
|
|
659cbc05a9 | ||
|
|
6ce65de2cd | ||
|
|
93b2eb3ff6 | ||
|
|
bf71300635 | ||
|
|
37ecd4a0bc | ||
|
|
827a1b181b | ||
|
|
c4e7cb75cd | ||
|
|
98e4bfcda8 | ||
|
|
ee48ca7671 | ||
|
|
4ba6de1116 | ||
|
|
bfbe636555 | ||
|
|
930fdc8fb4 | ||
|
|
fc90a8fb32 | ||
|
|
791f33fd0b | ||
|
|
1e0a3b163e | ||
|
|
bb1f1a56a5 | ||
|
|
15be85514d | ||
|
|
80cb85b845 | ||
|
|
0b131f1a8c | ||
|
|
0360f0b33b | ||
|
|
e51f2d68cb | ||
|
|
a6d4bf3399 | ||
|
|
113aa4ae08 | ||
|
|
5b40bf6d4e | ||
|
|
06ad8efd89 | ||
|
|
8513afbcf6 | ||
|
|
54ae43ef47 | ||
|
|
7a74b5ee3e | ||
|
|
560fe8a0f6 | ||
|
|
da27d261b0 | ||
|
|
c3e3a18ab4 | ||
|
|
ab34cea714 | ||
|
|
db0780cfa8 | ||
|
|
2b51fc23d9 | ||
|
|
0641773395 | ||
|
|
d745c2e8e3 | ||
|
|
e974c696f7 | ||
|
|
2ff280c4bf | ||
|
|
0e9d43d605 | ||
|
|
cc54363c27 | ||
|
|
c41140d654 | ||
|
|
89affe3139 | ||
|
|
f7fd065bee | ||
|
|
96c7c86e9d | ||
|
|
2c4977dbb1 | ||
|
|
e240175116 | ||
|
|
2398ed6fe8 | ||
|
|
a8420ac33c | ||
|
|
8470be6411 | ||
|
|
3d6295c622 | ||
|
|
ff2f7206f3 | ||
|
|
b937fc8978 | ||
|
|
5f0bd5119a | ||
|
|
86a9a51952 | ||
|
|
4188c9a1dd | ||
|
|
8833fee232 | ||
|
|
5bf642c3f9 | ||
|
|
8c00f89e36 | ||
|
|
9e8ac5c96b | ||
|
|
3d7d4182a6 | ||
|
|
75c221038d | ||
|
|
b7b5b0b8d0 | ||
|
|
05a67f4716 | ||
|
|
6eab6a675c | ||
|
|
f49476a206 | ||
|
|
c1e9c56e25 | ||
|
|
d5dd73cacf | ||
|
|
21f7a49b4e | ||
|
|
716ac04e13 | ||
|
|
c28a32fc47 | ||
|
|
31cba28e8a | ||
|
|
48cd7e6481 | ||
|
|
47aba1c9f9 | ||
|
|
d94e598a89 | ||
|
|
0f3f8bc0d9 | ||
|
|
e0df12c212 | ||
|
|
eb448d9bb8 | ||
|
|
0ba77f13db | ||
|
|
f0a2eb843c | ||
|
|
28acb70118 | ||
|
|
7c35aaa99d | ||
|
|
82fcf1da64 | ||
|
|
c662b95b80 | ||
|
|
a8c2a300f6 | ||
|
|
d654d9d8b1 | ||
|
|
eedc0ca6ea | ||
|
|
e4c3213978 | ||
|
|
520bc55da5 | ||
|
|
ebbb82178f | ||
|
|
d60a9ac63a | ||
|
|
f1486224e9 | ||
|
|
88b02d5a07 | ||
|
|
394b7d09b8 | ||
|
|
8353352bda | ||
|
|
e9313b9c1b | ||
|
|
5cf3d9e4d9 | ||
|
|
41958f55cd | ||
|
|
600ad232e1 | ||
|
|
7a3825cfce | ||
|
|
9519653422 | ||
|
|
efa2307c73 | ||
|
|
068fa3d0e3 | ||
|
|
13d8dbd542 | ||
|
|
e59cc3311d | ||
|
|
4f4e94f753 | ||
|
|
258970f489 | ||
|
|
781abe5f8f | ||
|
|
73845cbec5 | ||
|
|
b442ba8b2b | ||
|
|
c2f94e9e8a | ||
|
|
10e36d2355 | ||
|
|
13c53fedad | ||
|
|
4bda1bd884 | ||
|
|
3abe7850d6 | ||
|
|
b50284d864 | ||
|
|
81c6e52401 | ||
|
|
e54efda36f | ||
|
|
847d257366 | ||
|
|
687662cf1f | ||
|
|
6432d98469 | ||
|
|
088ccf8b8d | ||
|
|
e8683bf957 | ||
|
|
4653981b6b | ||
|
|
e2547413d3 | ||
|
|
ea17f41b5b | ||
|
|
29178d8adf | ||
|
|
d4bd19f6d8 | ||
|
|
7e86ead574 | ||
|
|
72debcb228 | ||
|
|
72737dabc7 | ||
|
|
4decbbbf18 | ||
|
|
f6e5cb4381 | ||
|
|
b15867f92e | ||
|
|
a5e5fbc6e0 | ||
|
|
1b1471b6d8 | ||
|
|
ffad3b5fb1 | ||
|
|
cba9fc3020 | ||
|
|
5280bffde2 | ||
|
|
e776accaf3 | ||
|
|
3592240d14 | ||
|
|
3eac26929a | ||
|
|
4d3adec738 | ||
|
|
ac5dd1f45a | ||
|
|
3005cf3282 | ||
|
|
54b272206e | ||
|
|
db0fc94b39 | ||
|
|
685f199f91 | ||
|
|
89bed479e4 | ||
|
|
5547247aa9 | ||
|
|
faa2d00cc6 | ||
|
|
fb8a356616 | ||
|
|
0c3aa8f5ec | ||
|
|
e2fd3f2983 | ||
|
|
f137af4ec5 | ||
|
|
fdd673a3a9 | ||
|
|
aed9955105 | ||
|
|
22f6d285c7 | ||
|
|
10aa16b471 | ||
|
|
3d761a3189 | ||
|
|
e3903f34e4 | ||
|
|
f4f055fb36 | ||
|
|
b3838581fd | ||
|
|
affbe7ccdb | ||
|
|
8563ae5511 | ||
|
|
2c765ccfae | ||
|
|
626e7b2211 | ||
|
|
516b6b0fa8 | ||
|
|
613d086f1e | ||
|
|
9e0630f012 | ||
|
|
d6d9554954 | ||
|
|
dd8577f832 | ||
|
|
2a532ab729 | ||
|
|
03eef65b25 | ||
|
|
ad07d63994 | ||
|
|
8685f055ea | ||
|
|
3b868a1cec | ||
|
|
ab389eaa8e | ||
|
|
008f778e8f | ||
|
|
d7f5da5df4 | ||
|
|
9fda130b3a | ||
|
|
72cdbdba0f | ||
|
|
b92a153902 | ||
|
|
9f2927979b | ||
|
|
75257232c3 | ||
|
|
1721314c62 | ||
|
|
fc230bcc59 | ||
|
|
b4636ddf44 | ||
|
|
f16151ea29 | ||
|
|
b1140301a4 | ||
|
|
58cd785da6 | ||
|
|
2035186cd2 | ||
|
|
53ba6aadff | ||
|
|
aa44c38b58 | ||
|
|
f091868b7c | ||
|
|
89bedae0d3 | ||
|
|
c8acc48976 | ||
|
|
21fee59b22 | ||
|
|
957a8253f8 | ||
|
|
d5fc3e7bed | ||
|
|
ab438b42da | ||
|
|
3867fece4a | ||
|
|
2b908d4fbe | ||
|
|
8ff062ec8b | ||
|
|
294fc41aec | ||
|
|
684f7df158 | ||
|
|
c3287755e3 | ||
|
|
9f97f4d79e | ||
|
|
34eb421649 | ||
|
|
850b05573e | ||
|
|
6ec8bfdfee | ||
|
|
81638c248e | ||
|
|
2e11b1298e | ||
|
|
20320f3a27 | ||
|
|
4019c12d26 | ||
|
|
cf72184ce4 | ||
|
|
ca8d15bc64 | ||
|
|
a91c897fd3 | ||
|
|
816bdf0320 | ||
|
|
d4a6acbd99 | ||
|
|
e421db4005 | ||
|
|
6af168cb31 | ||
|
|
29f56cf0cf | ||
|
|
11b6ea742d | ||
|
|
05d231ad33 | ||
|
|
48f3c69c69 | ||
|
|
9067c2a9c1 | ||
|
|
8b68020453 | ||
|
|
9f7321ca1a | ||
|
|
5fa01132b9 | ||
|
|
4d2fc66a8d | ||
|
|
f72ed4898c | ||
|
|
e082b6d599 | ||
|
|
d44be2d835 | ||
|
|
85a73181cc | ||
|
|
e31e4ab677 | ||
|
|
0d95c2192e | ||
|
|
7dc8557033 | ||
|
|
1fa8b26e55 | ||
|
|
4b085d46f6 | ||
|
|
72037a1865 | ||
|
|
635c4ed4ce | ||
|
|
7ffcf8dd6f | ||
|
|
97cd21d3be | ||
|
|
a13cb7e1c5 | ||
|
|
7b602e9003 | ||
|
|
5a26ebec8f | ||
|
|
8341b8b1c1 | ||
|
|
bbb640c9a2 | ||
|
|
0c97bbf137 | ||
|
|
45fddc70d5 | ||
|
|
f977dc410a | ||
|
|
d535818505 | ||
|
|
fcf4e1f37d | ||
|
|
38130c8502 | ||
|
|
f284c91988 | ||
|
|
584b2cefa3 | ||
|
|
42091b4a79 | ||
|
|
2d1621c43d | ||
|
|
d1a5db3310 | ||
|
|
ad8fd8fecc | ||
|
|
be74b76079 | ||
|
|
dd64af728f | ||
|
|
e43b46786d | ||
|
|
3f3b37b843 | ||
|
|
2ecf9f6ddf | ||
|
|
48c069fe68 | ||
|
|
9c5c597c85 | ||
|
|
c2eec8545d | ||
|
|
2395d4be26 | ||
|
|
9455476705 | ||
|
|
494e223706 | ||
|
|
348fd18230 | ||
|
|
7233b4de55 | ||
|
|
af6df05685 | ||
|
|
965b65db6e | ||
|
|
4cc01c8aa8 | ||
|
|
41372168b6 | ||
|
|
f4438b0a08 | ||
|
|
897c842637 | ||
|
|
ee86ceb906 | ||
|
|
e298732499 | ||
|
|
4081937e22 | ||
|
|
f9aedb2118 | ||
|
|
74b4719af8 | ||
|
|
2f35cc9188 | ||
|
|
2f966d8c38 | ||
|
|
b0868d9136 | ||
|
|
37440e9416 | ||
|
|
0d7d27ec0b |
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@@ -2,8 +2,6 @@ name: autofix.ci
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -6,9 +6,6 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# *db files
|
||||
*.db
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
@@ -238,7 +235,4 @@ scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
.serena/
|
||||
@@ -117,7 +117,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations<br/>**
|
||||
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
|
||||
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
|
||||
@@ -27,12 +27,12 @@ FILES_URL=http://localhost:5001
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
@@ -469,9 +469,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@@ -527,7 +524,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
# Workflow log cleanup configuration
|
||||
# Enable automatic cleanup of workflow run logs to manage database size
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=true
|
||||
# Number of days to retain workflow run logs (default: 30 days)
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
# Batch size for workflow log cleanup operations (default: 100)
|
||||
@@ -549,12 +546,6 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
@@ -627,8 +618,5 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
# Agent Skill Index
|
||||
|
||||
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Platform Foundations
|
||||
|
||||
- **[Infrastructure Overview](agent_skills/infra.md)**\
|
||||
When to read this:
|
||||
|
||||
- You need to understand where a feature belongs in the architecture.
|
||||
- You’re wiring storage, Redis, vector stores, or OTEL.
|
||||
- You’re about to add CLI commands or async jobs.\
|
||||
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
|
||||
|
||||
- **[Coding Style](agent_skills/coding_style.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re writing or reviewing backend code and need the authoritative checklist.
|
||||
- You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
|
||||
- You want the exact lint/type/test commands used in PRs.\
|
||||
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Plugin & Extension Development
|
||||
|
||||
- **[Plugin Systems](agent_skills/plugin.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re building or debugging a marketplace plugin.
|
||||
- You need to know how manifests, providers, daemons, and migrations fit together.\
|
||||
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
|
||||
|
||||
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
|
||||
When to read this:
|
||||
|
||||
- You must integrate OAuth for a plugin or datasource.
|
||||
- You’re handling credential encryption or refresh flows.\
|
||||
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Workflow Entry & Execution
|
||||
|
||||
- **[Trigger Concepts](agent_skills/trigger.md)**\
|
||||
When to read this:
|
||||
- You’re debugging why a workflow didn’t start.
|
||||
- You’re adding a new trigger type or hook.
|
||||
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
|
||||
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Additional Notes for Agents
|
||||
|
||||
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
|
||||
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
|
||||
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
|
||||
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
|
||||
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
|
||||
@@ -1,115 +0,0 @@
|
||||
## Linter
|
||||
|
||||
- Always follow `.ruff.toml`.
|
||||
- Run `uv run ruff check --fix --unsafe-fixes`.
|
||||
- Keep each line under 100 characters (including spaces).
|
||||
|
||||
## Code Style
|
||||
|
||||
- `snake_case` for variables and functions.
|
||||
- `PascalCase` for classes.
|
||||
- `UPPER_CASE` for constants.
|
||||
|
||||
## Rules
|
||||
|
||||
- Use Pydantic v2 standard.
|
||||
- Use `uv` for package management.
|
||||
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
|
||||
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
|
||||
- Prefer simple functions over classes for lightweight helpers.
|
||||
- Keep files below 800 lines; split when necessary.
|
||||
- Keep code readable—no clever hacks.
|
||||
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
- Mirror the project’s layered architecture: controller → service → core/domain.
|
||||
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
|
||||
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
|
||||
|
||||
## SQLAlchemy Patterns
|
||||
|
||||
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
|
||||
|
||||
- Open sessions with context managers:
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
|
||||
|
||||
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
|
||||
|
||||
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
|
||||
|
||||
## Storage & External IO
|
||||
|
||||
- Access storage via `extensions.ext_storage.storage`.
|
||||
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
|
||||
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
|
||||
|
||||
## Pydantic Usage
|
||||
|
||||
- Define DTOs with Pydantic v2 models and forbid extras by default.
|
||||
|
||||
- Use `@field_validator` / `@model_validator` for domain rules.
|
||||
|
||||
- Example:
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
|
||||
|
||||
class TriggerConfig(BaseModel):
|
||||
endpoint: HttpUrl
|
||||
secret: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("secret")
|
||||
def ensure_secret_prefix(cls, value: str) -> str:
|
||||
if not value.startswith("dify_"):
|
||||
raise ValueError("secret must start with dify_")
|
||||
return value
|
||||
```
|
||||
|
||||
## Generics & Protocols
|
||||
|
||||
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
|
||||
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
|
||||
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
|
||||
|
||||
## Error Handling & Logging
|
||||
|
||||
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
|
||||
- Declare `logger = logging.getLogger(__name__)` at module top.
|
||||
- Include tenant/app/workflow identifiers in log context.
|
||||
- Log retryable events at `warning`, terminal failures at `error`.
|
||||
|
||||
## Tooling & Checks
|
||||
|
||||
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
|
||||
- Type checks: `uv run --directory api --dev basedpyright`.
|
||||
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
- Run all of the above before submitting your work.
|
||||
|
||||
## Controllers & Services
|
||||
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
|
||||
- Document non-obvious behaviour with concise comments.
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
- Use `configs.dify_config` for configuration—never read environment variables directly.
|
||||
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
|
||||
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
|
||||
- Keep experimental scripts under `dev/`; do not ship them in production builds.
|
||||
@@ -1,96 +0,0 @@
|
||||
## Configuration
|
||||
|
||||
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
|
||||
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
|
||||
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
|
||||
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
|
||||
|
||||
## Dependencies
|
||||
|
||||
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
|
||||
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
|
||||
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
|
||||
|
||||
## Storage & Files
|
||||
|
||||
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
|
||||
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
|
||||
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
|
||||
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
|
||||
|
||||
## Redis & Shared State
|
||||
|
||||
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
|
||||
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
|
||||
|
||||
## Models
|
||||
|
||||
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
|
||||
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
|
||||
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
|
||||
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
|
||||
|
||||
## Vector Stores
|
||||
|
||||
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
|
||||
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
|
||||
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
|
||||
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
|
||||
|
||||
## Observability & OTEL
|
||||
|
||||
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
|
||||
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
|
||||
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
|
||||
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
|
||||
|
||||
## Ops Integrations
|
||||
|
||||
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
|
||||
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
|
||||
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
|
||||
|
||||
## Controllers, Services, Core
|
||||
|
||||
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
|
||||
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
|
||||
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
|
||||
|
||||
## Plugins, Tools, Providers
|
||||
|
||||
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
|
||||
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
|
||||
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
|
||||
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
|
||||
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
|
||||
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
|
||||
|
||||
## Async Workloads
|
||||
|
||||
see `agent_skills/trigger.md` for more detailed documentation.
|
||||
|
||||
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
|
||||
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
|
||||
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
|
||||
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
|
||||
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
|
||||
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
|
||||
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
|
||||
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
|
||||
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
|
||||
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
|
||||
|
||||
## When You Add Features
|
||||
|
||||
- Check for an existing helper or service before writing a new util.
|
||||
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
|
||||
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
|
||||
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
|
||||
@@ -1 +0,0 @@
|
||||
// TBD
|
||||
@@ -1 +0,0 @@
|
||||
// TBD
|
||||
@@ -1,53 +0,0 @@
|
||||
## Overview
|
||||
|
||||
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
|
||||
|
||||
## Trigger nodes
|
||||
|
||||
- `UserInput`
|
||||
- `Trigger Webhook`
|
||||
- `Trigger Schedule`
|
||||
- `Trigger Plugin`
|
||||
|
||||
### UserInput
|
||||
|
||||
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
|
||||
|
||||
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
|
||||
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
|
||||
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
|
||||
|
||||
### Trigger Webhook
|
||||
|
||||
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
|
||||
|
||||
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
|
||||
|
||||
### Trigger Schedule
|
||||
|
||||
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
|
||||
|
||||
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
|
||||
|
||||
### Trigger Plugin
|
||||
|
||||
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
|
||||
|
||||
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
|
||||
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
|
||||
|
||||
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
|
||||
|
||||
## Worker Pool / Async Task
|
||||
|
||||
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
|
||||
|
||||
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
|
||||
|
||||
## Debug Strategy
|
||||
|
||||
Dify divided users into 2 groups: builders / end users.
|
||||
|
||||
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
|
||||
|
||||
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.
|
||||
22
api/app.py
22
api/app.py
@@ -1,17 +1,24 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def is_db_command() -> bool:
|
||||
def is_db_command():
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# create app
|
||||
celery = None
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
@@ -22,8 +29,15 @@ else:
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
celery = flask_app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler
|
||||
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", 5001))
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
|
||||
@@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> DifyApp:
|
||||
def create_app() -> tuple[any, DifyApp]:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
import socketio
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return app
|
||||
return socketio_app, app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
@@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
@@ -1229,55 +1229,6 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
@@ -1471,10 +1422,7 @@ def setup_datasource_oauth_client(provider, client_params):
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
@click.option(
|
||||
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||
)
|
||||
def transform_datasource_credentials(environment: str):
|
||||
def transform_datasource_credentials():
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
@@ -1485,14 +1433,9 @@ def transform_datasource_credentials(environment: str):
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
jina_plugin_unique_identifier = None
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
|
||||
@@ -174,33 +174,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for async workflow
|
||||
"""
|
||||
|
||||
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
|
||||
description="Granularity for async workflow scheduler, "
|
||||
"sometime, few users could block the queue due to some time-consuming tasks, "
|
||||
"to avoid this, workflow can be suspended if needed, to achieve"
|
||||
"this, a time-based checker is required, every granularity seconds, "
|
||||
"the checker will check the workflow queue and suspend the workflow",
|
||||
default=120,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
@@ -290,8 +263,6 @@ class EndpointConfig(BaseSettings):
|
||||
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
|
||||
)
|
||||
|
||||
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
@@ -900,6 +871,16 @@ class MailConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||
description="Enable explore banner",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
@@ -1054,44 +1035,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
default=True,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
|
||||
description="Trigger provider refresh poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
|
||||
description="Max trigger subscriptions to process per tick",
|
||||
default=200,
|
||||
)
|
||||
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive credential refresh threshold in seconds",
|
||||
default=180,
|
||||
)
|
||||
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive subscription refresh threshold in seconds",
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@@ -1150,6 +1093,13 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1190,7 +1140,7 @@ class AccountConfig(BaseSettings):
|
||||
|
||||
|
||||
class WorkflowLogConfig(BaseSettings):
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
|
||||
default=100, description="Batch size for workflow run log cleanup operations"
|
||||
@@ -1209,9 +1159,9 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||
class TenantSelfTaskQueueConfig(BaseSettings):
|
||||
TENANT_SELF_TASK_QUEUE_PULL_SIZE: int = Field(
|
||||
description="Default batch size for tenant self task queue pull operations",
|
||||
default=1,
|
||||
)
|
||||
|
||||
@@ -1222,8 +1172,6 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
@@ -1242,12 +1190,13 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
TenantSelfTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
HOSTED_POOL_CREDITS: int = Field(
|
||||
description="Pool credits for hosted service",
|
||||
default=200,
|
||||
)
|
||||
|
||||
def get_model_credits(self, model_name: str) -> int:
|
||||
"""
|
||||
Get credit value for a specific model name.
|
||||
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
|
||||
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
default="gpt-4,"
|
||||
"gpt-4-turbo-preview,"
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-4-turbo,"
|
||||
"gpt-4.1,"
|
||||
"gpt-4.1-2025-04-14,"
|
||||
"gpt-4.1-mini,"
|
||||
"gpt-4.1-mini-2025-04-14,"
|
||||
"gpt-4.1-nano,"
|
||||
"gpt-4.1-nano-2025-04-14,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"text-davinci-003",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted OpenAI service usage",
|
||||
default=200,
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003,"
|
||||
"chatgpt-4o-latest,"
|
||||
"gpt-4o,"
|
||||
"gpt-4o-2024-05-13,"
|
||||
"gpt-4o-2024-08-06,"
|
||||
"gpt-4o-2024-11-20,"
|
||||
"gpt-4o-audio-preview,"
|
||||
"gpt-4o-audio-preview-2025-06-03,"
|
||||
"gpt-4o-mini,"
|
||||
"gpt-4o-mini-2024-07-18,"
|
||||
"o3-mini,"
|
||||
"o3-mini-2025-01-31,"
|
||||
"gpt-5-mini-2025-08-07,"
|
||||
"gpt-5-mini,"
|
||||
"o4-mini,"
|
||||
"o4-mini-2025-04-16,"
|
||||
"gpt-5-chat-latest,"
|
||||
"gpt-5,"
|
||||
"gpt-5-2025-08-07,"
|
||||
"gpt-5-nano,"
|
||||
"gpt-5-nano-2025-08-07",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
|
||||
@@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-4-turbo,"
|
||||
"gpt-4.1,"
|
||||
"gpt-4.1-2025-04-14,"
|
||||
"gpt-4.1-mini,"
|
||||
"gpt-4.1-mini-2025-04-14,"
|
||||
"gpt-4.1-nano,"
|
||||
"gpt-4.1-nano-2025-04-14,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
@@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003",
|
||||
"text-davinci-003,"
|
||||
"chatgpt-4o-latest,"
|
||||
"gpt-4o,"
|
||||
"gpt-4o-2024-05-13,"
|
||||
"gpt-4o-2024-08-06,"
|
||||
"gpt-4o-2024-11-20,"
|
||||
"gpt-4o-audio-preview,"
|
||||
"gpt-4o-audio-preview-2025-06-03,"
|
||||
"gpt-4o-mini,"
|
||||
"gpt-4o-mini-2024-07-18,"
|
||||
"o3-mini,"
|
||||
"o3-mini-2025-01-31,"
|
||||
"gpt-5-mini-2025-08-07,"
|
||||
"gpt-5-mini,"
|
||||
"o4-mini,"
|
||||
"o4-mini-2025-04-16,"
|
||||
"gpt-5-chat-latest,"
|
||||
"gpt-5,"
|
||||
"gpt-5-2025-08-07,"
|
||||
"gpt-5-nano,"
|
||||
"gpt-5-nano-2025-08-07",
|
||||
)
|
||||
|
||||
|
||||
class HostedGeminiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Gemini service
|
||||
"""
|
||||
|
||||
HOSTED_GEMINI_API_KEY: str | None = Field(
|
||||
description="API key for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Gemini API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
|
||||
class HostedXAIConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching XAI service
|
||||
"""
|
||||
|
||||
HOSTED_XAI_API_KEY: str | None = Field(
|
||||
description="API key for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted XAI API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedDeepseekConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Deepseek service
|
||||
"""
|
||||
|
||||
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
|
||||
description="API key for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Deepseek API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Deepseek service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
|
||||
@@ -144,16 +326,30 @@ class HostedAnthropicConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted Anthropic service usage",
|
||||
default=600000,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
@@ -250,5 +446,8 @@ class HostedServiceConfig(
|
||||
HostedModerationConfig,
|
||||
# credit config
|
||||
HostedCreditConfig,
|
||||
HostedGeminiConfig,
|
||||
HostedXAIConfig,
|
||||
HostedDeepseekConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
||||
|
||||
"""
|
||||
@@ -42,11 +41,3 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@@ -58,15 +58,16 @@ from .app import (
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
online_user,
|
||||
ops_trace,
|
||||
site,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
@@ -107,10 +108,12 @@ from .datasets.rag_pipeline import (
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
@@ -127,7 +130,6 @@ from .workspace import (
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
@@ -145,6 +147,7 @@ __all__ = [
|
||||
"apikey",
|
||||
"app",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
"bp",
|
||||
"completion",
|
||||
@@ -198,7 +201,7 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trigger_providers",
|
||||
"trial",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
@@ -206,6 +209,5 @@ __all__ = [
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
@@ -16,7 +16,7 @@ from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@@ -52,6 +52,8 @@ class InsertExploreAppListApi(Resource):
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
"can_trial": fields.Boolean(required=True, description="Can trial"),
|
||||
"trial_limit": fields.Integer(required=True, description="Trial limit"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -71,6 +73,8 @@ class InsertExploreAppListApi(Resource):
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
|
||||
.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -108,6 +112,20 @@ class InsertExploreAppListApi(Resource):
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
@@ -122,6 +140,20 @@ class InsertExploreAppListApi(Resource):
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
@@ -167,7 +199,83 @@ class InsertExploreAppApi(Resource):
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBanner(Resource):
|
||||
@api.doc("insert_explore_banner")
|
||||
@api.doc(description="Insert an explore banner")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InsertExploreBannerRequest",
|
||||
{
|
||||
"content": fields.String(required=True, description="Banner content"),
|
||||
"link": fields.String(required=True, description="Banner link"),
|
||||
"sort": fields.Integer(required=True, description="Banner sort"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Banner inserted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("title", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("img-src", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
parser.add_argument("language", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
content = {
|
||||
"category": args["category"],
|
||||
"title": args["title"],
|
||||
"description": args["description"],
|
||||
"img-src": args["img-src"],
|
||||
}
|
||||
|
||||
if not args["language"]:
|
||||
args["language"] = "en-US"
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=content,
|
||||
link=args["link"],
|
||||
sort=args["sort"],
|
||||
language=args["language"],
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBanner(Resource):
|
||||
@api.doc("delete_explore_banner")
|
||||
@api.doc(description="Delete an explore banner")
|
||||
@api.response(204, "Banner deleted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NeedAddIdsError(BaseHTTPException):
|
||||
error_code = "need_add_ids"
|
||||
description = "Need to add ids."
|
||||
code = 400
|
||||
@@ -11,7 +11,6 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
@@ -207,11 +206,13 @@ class InstructionGenerateApi(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
|
||||
339
api/controllers/console/app/online_user.py
Normal file
339
api/controllers/console/app/online_user.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from werkzeug.wrappers import Request as WerkzeugRequest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from services.account_service import AccountService
|
||||
|
||||
SESSION_STATE_TTL_SECONDS = 3600
|
||||
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
|
||||
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
|
||||
WS_SID_MAP_PREFIX = "ws_sid_map:"
|
||||
|
||||
|
||||
def _workflow_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
|
||||
|
||||
|
||||
def _leader_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
|
||||
|
||||
|
||||
def _sid_key(sid: str) -> str:
|
||||
return f"{WS_SID_MAP_PREFIX}{sid}"
|
||||
|
||||
|
||||
def _refresh_session_state(workflow_id: str, sid: str) -> None:
|
||||
"""
|
||||
Refresh TTLs for workflow + session keys so healthy sessions do not linger forever after crashes.
|
||||
"""
|
||||
workflow_key = _workflow_key(workflow_id)
|
||||
sid_key = _sid_key(sid)
|
||||
if redis_client.exists(workflow_key):
|
||||
redis_client.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
|
||||
if redis_client.exists(sid_key):
|
||||
redis_client.expire(sid_key, SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
|
||||
@sio.on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
token = None
|
||||
if auth and isinstance(auth, dict):
|
||||
token = auth.get("token")
|
||||
|
||||
if not token:
|
||||
try:
|
||||
request_environ = WerkzeugRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@sio.on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
session = sio.get_session(sid)
|
||||
user_id = session.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
# Each session is stored independently with sid as key
|
||||
session_info = {
|
||||
"user_id": user_id,
|
||||
"username": session.get("username", "Unknown"),
|
||||
"avatar": session.get("avatar", None),
|
||||
"sid": sid,
|
||||
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
|
||||
}
|
||||
|
||||
workflow_key = _workflow_key(workflow_id)
|
||||
# Store session info with sid as key
|
||||
redis_client.hset(workflow_key, sid, json.dumps(session_info))
|
||||
redis_client.set(
|
||||
_sid_key(sid),
|
||||
json.dumps({"workflow_id": workflow_id, "user_id": user_id}),
|
||||
ex=SESSION_STATE_TTL_SECONDS,
|
||||
)
|
||||
_refresh_session_state(workflow_id, sid)
|
||||
|
||||
# Leader election: first session becomes the leader
|
||||
leader_sid = get_or_set_leader(workflow_id, sid)
|
||||
is_leader = leader_sid == sid
|
||||
|
||||
sio.enter_room(sid, workflow_id)
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
# Notify this session of their leader status
|
||||
sio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@sio.on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
mapping = redis_client.get(_sid_key(sid))
|
||||
if mapping:
|
||||
data = json.loads(mapping)
|
||||
workflow_id = data["workflow_id"]
|
||||
|
||||
# Remove this specific session
|
||||
redis_client.hdel(_workflow_key(workflow_id), sid)
|
||||
redis_client.delete(_sid_key(sid))
|
||||
|
||||
# Handle leader re-election if the leader session disconnected
|
||||
handle_leader_disconnect(workflow_id, sid)
|
||||
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
|
||||
def _clear_session_state(workflow_id: str, sid: str) -> None:
|
||||
redis_client.hdel(_workflow_key(workflow_id), sid)
|
||||
redis_client.delete(_sid_key(sid))
|
||||
|
||||
|
||||
def _is_session_active(workflow_id: str, sid: str) -> bool:
|
||||
if not sid:
|
||||
return False
|
||||
|
||||
try:
|
||||
if not sio.manager.is_connected(sid, "/"):
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if not redis_client.hexists(_workflow_key(workflow_id), sid):
|
||||
return False
|
||||
|
||||
if not redis_client.exists(_sid_key(sid)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_or_set_leader(workflow_id: str, sid: str) -> str:
|
||||
"""
|
||||
Get current leader session or set this session as leader if no valid leader exists.
|
||||
Returns the leader session id (sid).
|
||||
"""
|
||||
raw_leader = redis_client.get(_leader_key(workflow_id))
|
||||
current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
|
||||
leader_replaced = False
|
||||
|
||||
if current_leader and not _is_session_active(workflow_id, current_leader):
|
||||
_clear_session_state(workflow_id, current_leader)
|
||||
redis_client.delete(_leader_key(workflow_id))
|
||||
current_leader = None
|
||||
leader_replaced = True
|
||||
|
||||
if not current_leader:
|
||||
redis_client.set(_leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) # Expire in 1 hour
|
||||
if leader_replaced:
|
||||
broadcast_leader_change(workflow_id, sid)
|
||||
return sid
|
||||
|
||||
return current_leader
|
||||
|
||||
|
||||
def handle_leader_disconnect(workflow_id, disconnected_sid):
|
||||
"""
|
||||
Handle leader re-election when a session disconnects.
|
||||
If the disconnected session was the leader, elect a new leader from remaining sessions.
|
||||
"""
|
||||
current_leader = redis_client.get(_leader_key(workflow_id))
|
||||
|
||||
if current_leader:
|
||||
current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
|
||||
|
||||
if current_leader == disconnected_sid:
|
||||
# Leader session disconnected, elect a new leader
|
||||
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
|
||||
|
||||
if sessions_json:
|
||||
# Get the first remaining session as new leader
|
||||
new_leader_sid = list(sessions_json.keys())[0]
|
||||
if isinstance(new_leader_sid, bytes):
|
||||
new_leader_sid = new_leader_sid.decode("utf-8")
|
||||
|
||||
redis_client.set(_leader_key(workflow_id), new_leader_sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
# Notify all sessions about the new leader
|
||||
broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
else:
|
||||
# No sessions left, remove leader
|
||||
redis_client.delete(_leader_key(workflow_id))
|
||||
|
||||
|
||||
def broadcast_leader_change(workflow_id, new_leader_sid):
|
||||
"""
|
||||
Broadcast leader change to all sessions in the workflow.
|
||||
"""
|
||||
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
|
||||
|
||||
for sid, session_info_json in sessions_json.items():
|
||||
try:
|
||||
sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid
|
||||
is_leader = sid_str == new_leader_sid
|
||||
# Emit to each session whether they are the new leader
|
||||
sio.emit("status", {"isLeader": is_leader}, room=sid_str)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
def get_current_leader(workflow_id):
|
||||
"""
|
||||
Get the current leader for a workflow.
|
||||
"""
|
||||
leader = redis_client.get(_leader_key(workflow_id))
|
||||
return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader
|
||||
|
||||
|
||||
def broadcast_online_users(workflow_id):
|
||||
"""
|
||||
Broadcast online users to the workflow room.
|
||||
Each session is shown as a separate user (even if same person has multiple tabs).
|
||||
"""
|
||||
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
|
||||
users = []
|
||||
|
||||
for sid, session_info_json in sessions_json.items():
|
||||
try:
|
||||
session_info = json.loads(session_info_json)
|
||||
# Each session appears as a separate "user" in the UI
|
||||
users.append(
|
||||
{
|
||||
"user_id": session_info["user_id"],
|
||||
"username": session_info["username"],
|
||||
"avatar": session_info.get("avatar"),
|
||||
"sid": session_info["sid"],
|
||||
"connected_at": session_info.get("connected_at"),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Sort by connection time to maintain consistent order
|
||||
users.sort(key=lambda x: x.get("connected_at") or 0)
|
||||
|
||||
# Get current leader session
|
||||
leader_sid = get_current_leader(workflow_id)
|
||||
|
||||
sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id)
|
||||
|
||||
|
||||
@sio.on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
|
||||
"""
|
||||
mapping = redis_client.get(_sid_key(sid))
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
user_id = mapping_data["user_id"]
|
||||
_refresh_session_state(workflow_id, sid)
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
if not event_type:
|
||||
return {"msg": "invalid event type"}, 400
|
||||
|
||||
sio.emit(
|
||||
"collaboration_update",
|
||||
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
|
||||
room=workflow_id,
|
||||
skip_sid=sid,
|
||||
)
|
||||
|
||||
return {"msg": "event_broadcasted"}
|
||||
|
||||
|
||||
@sio.on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
mapping = redis_client.get(_sid_key(sid))
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
_refresh_session_state(workflow_id, sid)
|
||||
|
||||
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
|
||||
|
||||
return {"msg": "graph_update_broadcasted"}
|
||||
@@ -5,10 +5,12 @@ from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
||||
from pydantic_core import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -16,22 +18,14 @@ from controllers.console.wraps import account_initialization_required, edit_perm
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.debug.event_selectors import (
|
||||
TriggerDebugEvent,
|
||||
TriggerDebugEventPoller,
|
||||
create_event_poller,
|
||||
select_trigger_debug_events,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
@@ -47,7 +41,6 @@ from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
@@ -110,6 +103,7 @@ class DraftWorkflowApi(Resource):
|
||||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -144,6 +138,8 @@ class DraftWorkflowApi(Resource):
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
.add_argument("force_upload", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("memory_blocks", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
@@ -161,6 +157,8 @@ class DraftWorkflowApi(Resource):
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"memory_blocks": data.get("memory_blocks"),
|
||||
"force_upload": data.get("force_upload", False),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
@@ -177,6 +175,10 @@ class DraftWorkflowApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
memory_blocks_list = args.get("memory_blocks") or []
|
||||
from core.memory.entities import MemoryBlockSpec
|
||||
|
||||
memory_blocks = [MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args["graph"],
|
||||
@@ -185,9 +187,13 @@ class DraftWorkflowApi(Resource):
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
force_upload=args.get("force_upload", False),
|
||||
memory_blocks=memory_blocks,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValidationError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@@ -756,6 +762,45 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@api.doc("get_workflow_config")
|
||||
@api.doc(description="Get workflow configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Workflow configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("features", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
features = args.get("features")
|
||||
|
||||
# Update draft workflow features
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.doc("get_all_published_workflows")
|
||||
@@ -939,232 +984,103 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run")
|
||||
class DraftWorkflowTriggerRunApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_run")
|
||||
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerRunRequest",
|
||||
{
|
||||
"node_id": fields.String(required=True, description="Node ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("workflow_ids", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
poller: TriggerDebugEventPoller = create_event_poller(
|
||||
draft_workflow=draft_workflow,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
event: TriggerDebugEvent | None = None
|
||||
try:
|
||||
event = poller.poll()
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
workflow_args = dict(event.workflow_args)
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
return helper.compact_generate_response(
|
||||
AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
root_node_id=node_id,
|
||||
)
|
||||
)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run")
|
||||
class DraftWorkflowTriggerNodeApi(Resource):
|
||||
"""
|
||||
Single node debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_node")
|
||||
@api.doc(description="Poll for trigger events and execute single node when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Trigger event received and node executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
|
||||
if not node_config:
|
||||
raise ValueError("Node data not found for node %s", node_id)
|
||||
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
|
||||
event: TriggerDebugEvent | None = None
|
||||
# for schedule trigger, when run single node, just execute directly
|
||||
if node_type == NodeType.TRIGGER_SCHEDULE:
|
||||
event = TriggerDebugEvent(
|
||||
workflow_args={},
|
||||
node_id=node_id,
|
||||
)
|
||||
# for other trigger types, poll for the event
|
||||
else:
|
||||
try:
|
||||
poller: TriggerDebugEventPoller = create_event_poller(
|
||||
draft_workflow=draft_workflow,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
event = poller.poll()
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
|
||||
raw_files = event.workflow_args.get("files")
|
||||
files = _parse_file(draft_workflow, raw_files if isinstance(raw_files, list) else None)
|
||||
try:
|
||||
node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model,
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=event.workflow_args.get("inputs") or {},
|
||||
account=current_user,
|
||||
query="",
|
||||
files=files,
|
||||
)
|
||||
return jsonable_encoder(node_execution)
|
||||
except Exception as e:
|
||||
logger.exception("Error running draft workflow trigger node")
|
||||
return jsonable_encoder(
|
||||
{"status": "error", "error": "An unexpected error occurred while running the node."}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
|
||||
class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
|
||||
"""
|
||||
|
||||
@api.doc("draft_workflow_trigger_run_all")
|
||||
@api.doc(description="Full workflow debug when the start node is a trigger")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerRunAllRequest",
|
||||
{
|
||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Full workflow debug when the start node is a trigger
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
node_ids = args["node_ids"]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
try:
|
||||
trigger_debug_event: TriggerDebugEvent | None = select_trigger_debug_events(
|
||||
draft_workflow=draft_workflow,
|
||||
app_model=app_model,
|
||||
user_id=current_user.id,
|
||||
node_ids=node_ids,
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
if trigger_debug_event is None:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
|
||||
try:
|
||||
workflow_args = dict(trigger_debug_event.workflow_args)
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
root_node_id=trigger_debug_event.node_id,
|
||||
)
|
||||
return helper.compact_generate_response(response)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger run-all")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowFeaturesApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/features",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowTaskStopApi,
|
||||
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigsApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ConvertToWorkflowApi,
|
||||
"/apps/<uuid:app_id>/convert-to-workflow",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowByIdApi,
|
||||
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeLastRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")
|
||||
|
||||
@@ -28,7 +28,6 @@ class WorkflowAppLogApi(Resource):
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"detail": "Whether to return detailed logs",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
@@ -69,7 +68,6 @@ class WorkflowAppLogApi(Resource):
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("detail", type=bool, location="args", required=False, default=False)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
@@ -94,7 +92,6 @@ class WorkflowAppLogApi(Resource):
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
detail=args.detail,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
240
api/controllers/console/app/workflow_comment.py
Normal file
240
api/controllers/console/app/workflow_comment.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import account_with_role_fields
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_basic_fields, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_create_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("position_x", type=float, required=True, location="json")
|
||||
parser.add_argument("position_y", type=float, required=True, location="json")
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=args.content,
|
||||
position_x=args.position_x,
|
||||
position_y=args.position_y,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_detail_fields)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_update_fields)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("position_x", type=float, required=False, location="json")
|
||||
parser.add_argument("position_y", type=float, required=False, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=args.content,
|
||||
position_x=args.position_x,
|
||||
position_y=args.position_y,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_resolve_fields)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_reply_create_fields)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=args.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_reply_update_fields)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"users": members}
|
||||
|
||||
|
||||
# Register API routes
|
||||
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
|
||||
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
api.add_resource(
|
||||
WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
|
||||
)
|
||||
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
@@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
@@ -355,7 +355,7 @@ class VariableApi(Resource):
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
@@ -448,8 +448,35 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_variables", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@api.doc("get_system_variables")
|
||||
@api.doc(description="Get system variables for workflow")
|
||||
@@ -499,3 +526,44 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = str(args["node_id"])
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.where(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
@@ -23,6 +23,15 @@ def _load_app_model(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
@@ -62,3 +71,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
app_id = kwargs.get("app_id")
|
||||
app_id = str(app_id)
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model_with_trial(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
if mode is not None:
|
||||
if isinstance(mode, list):
|
||||
modes = mode
|
||||
else:
|
||||
modes = [mode]
|
||||
|
||||
if app_mode not in modes:
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@@ -25,10 +25,13 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.passport import PassportService
|
||||
from libs.token import (
|
||||
check_csrf_token,
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_access_token,
|
||||
extract_refresh_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
@@ -289,3 +292,18 @@ class RefreshTokenApi(Resource):
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
# this api helps frontend to check whether user is authenticated
|
||||
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
|
||||
@console_ns.route("/login/status")
|
||||
class LoginStatus(Resource):
|
||||
def get(self):
|
||||
token = extract_access_token(request)
|
||||
res = True
|
||||
try:
|
||||
validated = PassportService().verify(token=token)
|
||||
check_csrf_token(request=request, user_id=validated.get("user_id"))
|
||||
except:
|
||||
res = False
|
||||
return {"logged_in": res}
|
||||
|
||||
43
api/controllers/console/explore/banner.py
Normal file
43
api/controllers/console/explore/banner.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"id": banner.id,
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
@@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
||||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
||||
@@ -27,6 +27,7 @@ recommended_app_fields = {
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
||||
514
api/controllers/console/explore/trial.py
Normal file
514
api/controllers/console/explore/trial.py
Normal file
@@ -0,0 +1,514 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common import fields
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NeedAddIdsError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model_with_trial
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotWorkflowAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields_with_site
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, app_model):
|
||||
"""Get workflow detail"""
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
|
||||
tenant_id = app_model.tenant_id
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
|
||||
else:
|
||||
raise NeedAddIdsError()
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
|
||||
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
@@ -2,14 +2,16 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import InstalledApp
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -71,6 +73,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
@@ -80,3 +135,13 @@ class InstalledAppResource(Resource):
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
@@ -32,6 +32,7 @@ from controllers.console.wraps import (
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -128,6 +129,17 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("avatar", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args["avatar"])
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@@ -114,25 +114,6 @@ class PluginIconApi(Resource):
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/asset")
|
||||
class PluginAssetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
req.add_argument("file_name", type=str, required=True, location="args")
|
||||
args = req.parse_args()
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
try:
|
||||
binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
|
||||
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/pkg")
|
||||
class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@@ -577,21 +558,19 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
.add_argument("provider", type=str, required=True, location="args")
|
||||
.add_argument("action", type=str, required=True, location="args")
|
||||
.add_argument("parameter", type=str, required=True, location="args")
|
||||
.add_argument("credential_id", type=str, required=False, location="args")
|
||||
.add_argument("provider_type", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=args["plugin_id"],
|
||||
provider=args["provider"],
|
||||
action=args["action"],
|
||||
parameter=args["parameter"],
|
||||
credential_id=args["credential_id"],
|
||||
provider_type=args["provider_type"],
|
||||
tenant_id,
|
||||
user_id,
|
||||
args["plugin_id"],
|
||||
args["provider"],
|
||||
args["action"],
|
||||
args["parameter"],
|
||||
args["provider_type"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
@@ -707,23 +686,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
args = req.parse_args()
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/readme")
|
||||
class PluginReadmeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
parser.add_argument("language", type=str, required=False, location="args")
|
||||
args = parser.parse_args()
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"readme": PluginService.fetch_plugin_readme(
|
||||
tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -21,14 +21,12 @@ from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
|
||||
# from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
@@ -1,592 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
|
||||
|
||||
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
return jsonable_encoder({"error": str(e)}), 404
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
)
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
except Exception as e:
|
||||
logger.exception("Error adding provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Use atomic update_and_verify to prevent race conditions
|
||||
return TriggerSubscriptionBuilderService.update_and_verify_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||
except Exception as e:
|
||||
logger.exception("Error getting request logs for subscription builder", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
# Use atomic update_and_build to prevent race conditions
|
||||
TriggerSubscriptionBuilderService.update_and_build_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
),
|
||||
)
|
||||
return 200
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
class TriggerSubscriptionDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id: str):
|
||||
"""Delete a subscription instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
# Delete plugin triggers
|
||||
TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise NotFoundError("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
)
|
||||
|
||||
# Create OAuth handler and proxy context
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
extra_data={
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
},
|
||||
)
|
||||
|
||||
# Build redirect URI for callback
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response.authorization_url,
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
"subscription_builder": subscription_builder,
|
||||
}
|
||||
)
|
||||
)
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
# Use and validate proxy context
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Get OAuth credentials from callback
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
|
||||
credentials = credentials_response.credentials
|
||||
expires_at = credentials_response.expires_at
|
||||
|
||||
if not credentials:
|
||||
raise ValueError("Failed to get OAuth credentials from the provider.")
|
||||
|
||||
# Update subscription builder
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=credentials,
|
||||
credential_expires_at=expires_at,
|
||||
),
|
||||
)
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
# Get custom OAuth client params if exists
|
||||
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if custom client is enabled
|
||||
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
system_client_exists = TriggerProviderService.is_oauth_system_client_exists(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"configured": bool(custom_params or system_client_exists),
|
||||
"system_configured": system_client_exists,
|
||||
"custom_configured": bool(custom_params),
|
||||
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
|
||||
"custom_enabled": is_custom_enabled,
|
||||
"redirect_uri": redirect_uri,
|
||||
"params": custom_params or {},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
# Trigger Subscription
|
||||
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
@@ -51,6 +51,9 @@ tenant_fields = {
|
||||
"in_trial": fields.Boolean,
|
||||
"trial_end_reason": fields.String,
|
||||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
"trial_credits": fields.Integer,
|
||||
"trial_credits_used": fields.Integer,
|
||||
"next_credit_reset_date": fields.Integer,
|
||||
}
|
||||
|
||||
tenants_fields = {
|
||||
|
||||
@@ -19,6 +19,7 @@ from .app import (
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
chatflow_memory,
|
||||
completion,
|
||||
conversation,
|
||||
file,
|
||||
@@ -40,6 +41,7 @@ __all__ = [
|
||||
"annotation",
|
||||
"app",
|
||||
"audio",
|
||||
"chatflow_memory",
|
||||
"completion",
|
||||
"conversation",
|
||||
"dataset",
|
||||
|
||||
109
api/controllers/service_api/app/chatflow_memory.py
Normal file
109
api/controllers/service_api/app/chatflow_memory.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.memory.entities import MemoryBlock, MemoryCreatedBy
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from models import App, EndUser
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MemoryListApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("memory_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("version", required=False, type=int | None, default=None)
|
||||
args = parser.parse_args()
|
||||
conversation_id: str | None = args.get("conversation_id")
|
||||
memory_id = args.get("memory_id")
|
||||
version = args.get("version")
|
||||
|
||||
if conversation_id:
|
||||
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
result = [*result, *session_memories]
|
||||
else:
|
||||
result = ChatflowMemoryService.get_persistent_memories(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), version
|
||||
)
|
||||
|
||||
if memory_id:
|
||||
result = [it for it in result if it.spec.id == memory_id]
|
||||
return [it for it in result if it.spec.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def put(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True)
|
||||
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("node_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("update", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
workflow = WorkflowService().get_published_workflow(app_model)
|
||||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {"error": "Invalid update"}, 400
|
||||
if not workflow:
|
||||
return {"error": "Workflow not found"}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
|
||||
if not memory_spec:
|
||||
return {"error": "Memory not found"}, 404
|
||||
|
||||
# First get existing memory
|
||||
existing_memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=False,
|
||||
)
|
||||
|
||||
# Create updated memory instance with incremented version
|
||||
updated_memory = MemoryBlock(
|
||||
spec=existing_memory.spec,
|
||||
tenant_id=existing_memory.tenant_id,
|
||||
app_id=existing_memory.app_id,
|
||||
conversation_id=existing_memory.conversation_id,
|
||||
node_id=existing_memory.node_id,
|
||||
value=update, # New value
|
||||
version=existing_memory.version + 1, # Increment version for update
|
||||
edited_by_user=True,
|
||||
created_by=existing_memory.created_by,
|
||||
)
|
||||
|
||||
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
|
||||
return "", 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get("id")
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(app_model, memory_id, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(app_model, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, "/memories")
|
||||
api.add_resource(MemoryEditApi, "/memory-edit")
|
||||
api.add_resource(MemoryDeleteApi, "/memories")
|
||||
@@ -20,8 +20,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.end_user_service import EndUserService
|
||||
from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
@@ -85,7 +84,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
end_user = EndUserService.get_or_create_end_user(app_model, user_id)
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
kwargs["end_user"] = end_user
|
||||
|
||||
# Set EndUser as current logged-in user for flask_login.current_user
|
||||
@@ -332,6 +331,39 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
return api_token
|
||||
|
||||
|
||||
def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = None) -> EndUser:
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
end_user = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == "service_api",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from flask import Blueprint
|
||||
|
||||
# Create trigger blueprint
|
||||
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
|
||||
|
||||
# Import routes after blueprint creation to avoid circular imports
|
||||
from . import trigger, webhook
|
||||
|
||||
__all__ = [
|
||||
"trigger",
|
||||
"webhook",
|
||||
]
|
||||
@@ -1,43 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger.trigger_service import TriggerService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||
|
||||
|
||||
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_endpoint(endpoint_id: str):
|
||||
"""
|
||||
Handle endpoint trigger calls.
|
||||
"""
|
||||
# endpoint_id must be UUID
|
||||
if not UUID_MATCHER.match(endpoint_id):
|
||||
raise NotFound("Invalid endpoint ID")
|
||||
handling_chain = [
|
||||
TriggerService.process_endpoint,
|
||||
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
|
||||
]
|
||||
response = None
|
||||
try:
|
||||
for handler in handling_chain:
|
||||
response = handler(endpoint_id, request)
|
||||
if response:
|
||||
break
|
||||
if not response:
|
||||
logger.error("Endpoint not found for {endpoint_id}")
|
||||
return jsonify({"error": "Endpoint not found"}), 404
|
||||
return response
|
||||
except ValueError as e:
|
||||
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 400
|
||||
except Exception:
|
||||
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||
return jsonify({"error": "Internal server error"}), 500
|
||||
@@ -1,105 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import jsonify
|
||||
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.trigger import bp
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
||||
"""Fetch trigger context, extract request data, and validate payload using unified processing.
|
||||
|
||||
Args:
|
||||
webhook_id: The webhook ID to process
|
||||
is_debug: If True, skip status validation for debug mode
|
||||
"""
|
||||
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
webhook_id, is_debug=is_debug
|
||||
)
|
||||
|
||||
try:
|
||||
# Use new unified extraction and validation
|
||||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
return webhook_trigger, workflow, node_config, webhook_data, None
|
||||
except ValueError as e:
|
||||
# Fall back to raw extraction for error reporting
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
return webhook_trigger, workflow, node_config, webhook_data, str(e)
|
||||
|
||||
|
||||
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook(webhook_id: str):
|
||||
"""
|
||||
Handle webhook trigger calls.
|
||||
|
||||
This endpoint receives webhook calls and processes them according to the
|
||||
configured webhook trigger settings.
|
||||
"""
|
||||
try:
|
||||
webhook_trigger, workflow, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id)
|
||||
if error:
|
||||
return jsonify({"error": "Bad Request", "message": error}), 400
|
||||
|
||||
# Process webhook call (send to Celery)
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Return configured response
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except RequestEntityTooLarge:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for %s", webhook_id)
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
|
||||
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook_debug(webhook_id: str):
|
||||
"""Handle webhook debug calls without triggering production workflow execution."""
|
||||
try:
|
||||
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
|
||||
if error:
|
||||
return jsonify({"error": "Bad Request", "message": error}), 400
|
||||
|
||||
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
|
||||
|
||||
# Generate pool key and dispatch debug event
|
||||
pool_key: str = build_webhook_pool_key(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
app_id=webhook_trigger.app_id,
|
||||
node_id=webhook_trigger.node_id,
|
||||
)
|
||||
event = WebhookDebugEvent(
|
||||
request_id=f"webhook_debug_{webhook_trigger.webhook_id}_{int(time.time() * 1000)}",
|
||||
timestamp=int(time.time()),
|
||||
node_id=webhook_trigger.node_id,
|
||||
payload={
|
||||
"inputs": workflow_inputs,
|
||||
"webhook_data": webhook_data,
|
||||
"method": webhook_data.get("method"),
|
||||
},
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
event=event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except RequestEntityTooLarge:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Webhook debug processing failed for %s", webhook_id)
|
||||
return jsonify({"error": "Internal server error", "message": "An internal error has occurred."}), 500
|
||||
@@ -18,6 +18,7 @@ web_ns = Namespace("web", description="Web application API operations", path="/"
|
||||
from . import (
|
||||
app,
|
||||
audio,
|
||||
chatflow_memory,
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
@@ -39,6 +40,7 @@ __all__ = [
|
||||
"app",
|
||||
"audio",
|
||||
"bp",
|
||||
"chatflow_memory",
|
||||
"completion",
|
||||
"conversation",
|
||||
"feature",
|
||||
|
||||
108
api/controllers/web/chatflow_memory.py
Normal file
108
api/controllers/web/chatflow_memory.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from flask_restx import reqparse
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.memory.entities import MemoryBlock, MemoryCreatedBy
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from models import App, EndUser
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MemoryListApi(WebApiResource):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("memory_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("version", required=False, type=int | None, default=None)
|
||||
args = parser.parse_args()
|
||||
conversation_id: str | None = args.get("conversation_id")
|
||||
memory_id = args.get("memory_id")
|
||||
version = args.get("version")
|
||||
|
||||
if conversation_id:
|
||||
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
result = [*result, *session_memories]
|
||||
else:
|
||||
result = ChatflowMemoryService.get_persistent_memories(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), version
|
||||
)
|
||||
|
||||
if memory_id:
|
||||
result = [it for it in result if it.spec.id == memory_id]
|
||||
return [it for it in result if it.spec.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(WebApiResource):
|
||||
def put(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True)
|
||||
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("node_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("update", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
workflow = WorkflowService().get_published_workflow(app_model)
|
||||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {"error": "Update must be a string"}, 400
|
||||
if not workflow:
|
||||
return {"error": "Workflow not found"}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
|
||||
if not memory_spec:
|
||||
return {"error": "Memory not found"}, 404
|
||||
if not memory_spec.end_user_editable:
|
||||
return {"error": "Memory not editable"}, 403
|
||||
|
||||
# First get existing memory
|
||||
existing_memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=False,
|
||||
)
|
||||
|
||||
# Create updated memory instance with incremented version
|
||||
updated_memory = MemoryBlock(
|
||||
spec=existing_memory.spec,
|
||||
tenant_id=existing_memory.tenant_id,
|
||||
app_id=existing_memory.app_id,
|
||||
conversation_id=existing_memory.conversation_id,
|
||||
node_id=existing_memory.node_id,
|
||||
value=update, # New value
|
||||
version=existing_memory.version + 1, # Increment version for update
|
||||
edited_by_user=True,
|
||||
created_by=existing_memory.created_by,
|
||||
)
|
||||
|
||||
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
|
||||
return "", 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(WebApiResource):
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get("id")
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(app_model, memory_id, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(app_model, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, "/memories")
|
||||
api.add_resource(MemoryEditApi, "/memory-edit")
|
||||
api.add_resource(MemoryDeleteApi, "/memories")
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -20,6 +21,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.memory.entities import MemoryCreatedBy, MemoryScope
|
||||
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
@@ -27,6 +30,7 @@ from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.graph_events import GraphRunSucceededEvent
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@@ -39,6 +43,8 @@ from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.chatflow_history_service import ChatflowHistoryService
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,6 +87,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def run(self):
|
||||
ChatflowMemoryService.wait_for_sync_memory_completion(
|
||||
workflow=self._workflow, conversation_id=self.conversation.id
|
||||
)
|
||||
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
@@ -143,6 +153,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
memory_blocks=self._fetch_memory_blocks(),
|
||||
)
|
||||
|
||||
# init graph
|
||||
@@ -206,6 +217,31 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
try:
|
||||
self._check_app_memory_updates(variable_pool)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to check app memory updates", exc_info=e)
|
||||
|
||||
@override
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: Any) -> None:
|
||||
super()._handle_event(workflow_entry, event)
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
workflow_outputs = event.outputs
|
||||
if not workflow_outputs:
|
||||
logger.warning("Chatflow output is empty.")
|
||||
return
|
||||
assistant_message = workflow_outputs.get("answer")
|
||||
if not assistant_message:
|
||||
logger.warning("Chatflow output does not contain 'answer'.")
|
||||
return
|
||||
if not isinstance(assistant_message, str):
|
||||
logger.warning("Chatflow output 'answer' is not a string.")
|
||||
return
|
||||
try:
|
||||
self._sync_conversation_to_chatflow_tables(assistant_message)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to sync conversation to memory tables", exc_info=e)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
app_record: App,
|
||||
@@ -403,3 +439,67 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# Return combined list
|
||||
return existing_variables + new_variables
|
||||
|
||||
def _fetch_memory_blocks(self) -> Mapping[str, str]:
|
||||
"""fetch all memory blocks for current app"""
|
||||
|
||||
memory_blocks_dict: MutableMapping[str, str] = {}
|
||||
is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
conversation_id = self.conversation.id
|
||||
memory_block_specs = self._workflow.memory_blocks
|
||||
# Get runtime memory values
|
||||
memories = ChatflowMemoryService.get_memories_by_specs(
|
||||
memory_block_specs=memory_block_specs,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
node_id=None,
|
||||
conversation_id=conversation_id,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_created_by(),
|
||||
)
|
||||
|
||||
# Build memory_id -> value mapping
|
||||
for memory in memories:
|
||||
if memory.spec.scope == MemoryScope.APP:
|
||||
# App level: use memory_id directly
|
||||
memory_blocks_dict[memory.spec.id] = memory.value
|
||||
else: # NODE scope
|
||||
node_id = memory.node_id
|
||||
if not node_id:
|
||||
logger.warning("Memory block %s has no node_id, skip.", memory.spec.id)
|
||||
continue
|
||||
key = f"{node_id}.{memory.spec.id}"
|
||||
memory_blocks_dict[key] = memory.value
|
||||
|
||||
return memory_blocks_dict
|
||||
|
||||
def _sync_conversation_to_chatflow_tables(self, assistant_message: str):
|
||||
ChatflowHistoryService.save_app_message(
|
||||
prompt_message=UserPromptMessage(content=(self.application_generate_entity.query)),
|
||||
conversation_id=self.conversation.id,
|
||||
app_id=self._workflow.app_id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
)
|
||||
ChatflowHistoryService.save_app_message(
|
||||
prompt_message=AssistantPromptMessage(content=assistant_message),
|
||||
conversation_id=self.conversation.id,
|
||||
app_id=self._workflow.app_id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
)
|
||||
|
||||
def _check_app_memory_updates(self, variable_pool: VariablePool):
|
||||
is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
|
||||
ChatflowMemoryService.update_app_memory_if_needed(
|
||||
workflow=self._workflow,
|
||||
conversation_id=self.conversation.id,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_created_by(),
|
||||
)
|
||||
|
||||
def _get_created_by(self) -> MemoryCreatedBy:
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
return MemoryCreatedBy(account_id=self.application_generate_entity.user_id)
|
||||
else:
|
||||
return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id)
|
||||
|
||||
@@ -37,7 +37,6 @@ from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
@@ -304,11 +303,6 @@ class WorkflowResponseConverter:
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
elif event.node_type == NodeType.TRIGGER_PLUGIN:
|
||||
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
@@ -47,7 +48,6 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -244,7 +244,34 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
if rag_pipeline_invoke_entities:
|
||||
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
|
||||
# store the rag_pipeline_invoke_entities to object storage
|
||||
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
|
||||
name = "rag_pipeline_invoke_entities.json"
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
|
||||
|
||||
if redis_client.get(tenant_pipeline_task_key):
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
else:
|
||||
priority_rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
# return batch, dataset, documents
|
||||
return {
|
||||
"batch": batch,
|
||||
|
||||
@@ -27,7 +27,6 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@@ -39,16 +38,10 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@@ -60,10 +53,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@@ -76,9 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -92,10 +79,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -107,10 +91,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
@@ -145,20 +126,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# for trigger debug run, not prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=inputs,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -177,10 +155,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
@@ -207,16 +182,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
|
||||
def resume(self, *, workflow_run_id: str) -> None:
|
||||
"""
|
||||
@TBD
|
||||
"""
|
||||
pass
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
@@ -229,8 +196,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -266,10 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": graph_engine_layers,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -463,8 +426,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -508,8 +469,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
system_user_id=system_user_id,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -9,7 +8,6 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@@ -18,7 +16,6 @@ from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
@@ -38,21 +35,17 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: str | None = None,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
@@ -67,7 +60,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
@@ -100,7 +92,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
|
||||
@@ -84,7 +84,6 @@ class WorkflowBasedAppRunner:
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
root_node_id: str | None = None,
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
@@ -118,7 +117,7 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
@@ -32,10 +32,6 @@ class InvokeFrom(StrEnum):
|
||||
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
|
||||
WEB_APP = "web-app"
|
||||
|
||||
# TRIGGER indicates that this invocation is from a trigger.
|
||||
# this is used for plugin trigger and webhook trigger.
|
||||
TRIGGER = "trigger"
|
||||
|
||||
# EXPLORE indicates that this invocation is from
|
||||
# the workflow (or chatflow) explore page.
|
||||
EXPLORE = "explore"
|
||||
@@ -44,9 +40,6 @@ class InvokeFrom(StrEnum):
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
|
||||
# VALIDATION indicates that this invocation is from validation.
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
@@ -72,8 +65,6 @@ class InvokeFrom(StrEnum):
|
||||
return "dev"
|
||||
elif self == InvokeFrom.EXPLORE:
|
||||
return "explore_app"
|
||||
elif self == InvokeFrom.TRIGGER:
|
||||
return "trigger"
|
||||
elif self == InvokeFrom.SERVICE_API:
|
||||
return "api"
|
||||
|
||||
@@ -113,11 +104,6 @@ class AppGenerateEntity(BaseModel):
|
||||
|
||||
inputs: Mapping[str, Any]
|
||||
files: Sequence[File]
|
||||
|
||||
# Unique identifier of the user initiating the execution.
|
||||
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
|
||||
#
|
||||
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
|
||||
user_id: str
|
||||
|
||||
# extras
|
||||
|
||||
@@ -1,64 +1,15 @@
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from models.model import AppMode
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
# Wrapper types for `WorkflowAppGenerateEntity` and
|
||||
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
|
||||
# and correct reconstruction of the entity field during (de)serialization.
|
||||
class _WorkflowGenerateEntityWrapper(BaseModel):
|
||||
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
|
||||
entity: WorkflowAppGenerateEntity
|
||||
|
||||
|
||||
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
|
||||
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
|
||||
entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
_GenerateEntityUnion: TypeAlias = Annotated[
|
||||
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class WorkflowResumptionContext(BaseModel):
|
||||
"""WorkflowResumptionContext captures all state necessary for resumption."""
|
||||
|
||||
version: Literal["1"] = "1"
|
||||
|
||||
# Only workflow / chatflow could be paused.
|
||||
generate_entity: _GenerateEntityUnion
|
||||
serialized_graph_runtime_state: str
|
||||
|
||||
def dumps(self) -> str:
|
||||
return self.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def loads(cls, value: str) -> Self:
|
||||
return cls.model_validate_json(value)
|
||||
|
||||
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
|
||||
return self.generate_entity.entity
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Engine | sessionmaker[Session],
|
||||
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
|
||||
state_owner_user_id: str,
|
||||
):
|
||||
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
|
||||
"""Create a PauseStatePersistenceLayer.
|
||||
|
||||
The `state_owner_user_id` is used when creating state file for pause.
|
||||
@@ -68,7 +19,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
self._generate_entity = generate_entity
|
||||
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
@@ -99,25 +49,13 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
|
||||
entity_wrapper: _GenerateEntityUnion
|
||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||
else:
|
||||
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
|
||||
|
||||
state = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
|
||||
generate_entity=entity_wrapper,
|
||||
)
|
||||
|
||||
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
assert workflow_run_id is not None
|
||||
repo = self._get_repo()
|
||||
repo.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=state.dumps(),
|
||||
state=self.graph_runtime_state.dumps(),
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
|
||||
|
||||
class SuspendLayer(GraphEngineLayer):
|
||||
""" """
|
||||
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
"""
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
pass
|
||||
@@ -1,88 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import ClassVar
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
|
||||
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimeSliceLayer(GraphEngineLayer):
|
||||
"""
|
||||
CFS plan scheduler to control the timeslice of the workflow.
|
||||
"""
|
||||
|
||||
scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()
|
||||
|
||||
def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
|
||||
"""
|
||||
CFS plan scheduler allows to control the timeslice of the workflow.
|
||||
"""
|
||||
|
||||
if not TimeSliceLayer.scheduler.running:
|
||||
TimeSliceLayer.scheduler.start()
|
||||
|
||||
super().__init__()
|
||||
self.cfs_plan_scheduler = cfs_plan_scheduler
|
||||
self.stopped = False
|
||||
self.schedule_id = ""
|
||||
|
||||
def _checker_job(self, schedule_id: str):
|
||||
"""
|
||||
Check if the workflow need to be suspended.
|
||||
"""
|
||||
try:
|
||||
if self.stopped:
|
||||
self.scheduler.remove_job(schedule_id)
|
||||
return
|
||||
|
||||
if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
|
||||
# remove the job
|
||||
self.scheduler.remove_job(schedule_id)
|
||||
|
||||
if not self.command_channel:
|
||||
logger.exception("No command channel to stop the workflow")
|
||||
return
|
||||
|
||||
# send command to pause the workflow
|
||||
self.command_channel.send_command(
|
||||
GraphEngineCommand(
|
||||
command_type=CommandType.PAUSE,
|
||||
payload={
|
||||
"reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("scheduler error during check if the workflow need to be suspended")
|
||||
|
||||
def on_graph_start(self):
|
||||
"""
|
||||
Start timer to check if the workflow need to be suspended.
|
||||
"""
|
||||
|
||||
if self.cfs_plan_scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice:
|
||||
self.schedule_id = uuid.uuid4().hex
|
||||
|
||||
self.scheduler.add_job(
|
||||
lambda: self._checker_job(self.schedule_id),
|
||||
"interval",
|
||||
seconds=self.cfs_plan_scheduler.plan.granularity,
|
||||
id=self.schedule_id,
|
||||
)
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self.stopped = True
|
||||
# remove the scheduler
|
||||
if self.schedule_id:
|
||||
self.scheduler.remove_job(self.schedule_id)
|
||||
@@ -1,88 +0,0 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
|
||||
from models.enums import WorkflowTriggerStatus
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerPostLayer(GraphEngineLayer):
|
||||
"""
|
||||
Trigger post layer.
|
||||
"""
|
||||
|
||||
_STATUS_MAP: ClassVar[dict[type[GraphEngineEvent], WorkflowTriggerStatus]] = {
|
||||
GraphRunSucceededEvent: WorkflowTriggerStatus.SUCCEEDED,
|
||||
GraphRunFailedEvent: WorkflowTriggerStatus.FAILED,
|
||||
GraphRunPausedEvent: WorkflowTriggerStatus.PAUSED,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
|
||||
start_time: datetime,
|
||||
trigger_log_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
):
|
||||
self.trigger_log_id = trigger_log_id
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
self.session_maker = session_maker
|
||||
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Update trigger log with success or failure.
|
||||
"""
|
||||
if isinstance(event, tuple(self._STATUS_MAP.keys())):
|
||||
with self.session_maker() as session:
|
||||
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = repo.get_by_id(self.trigger_log_id)
|
||||
if not trigger_log:
|
||||
logger.exception("Trigger log not found: %s", self.trigger_log_id)
|
||||
return
|
||||
|
||||
# Calculate elapsed time
|
||||
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
||||
|
||||
# Extract relevant data from result
|
||||
if not self.graph_runtime_state:
|
||||
logger.exception("Graph runtime state is not set")
|
||||
return
|
||||
|
||||
outputs = self.graph_runtime_state.outputs
|
||||
|
||||
# BASICLY, workflow_execution_id is the same as workflow_run_id
|
||||
workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
assert workflow_run_id, "Workflow run id is not set"
|
||||
|
||||
total_tokens = self.graph_runtime_state.total_tokens
|
||||
|
||||
# Update trigger log with success
|
||||
trigger_log.status = self._STATUS_MAP[type(event)]
|
||||
trigger_log.workflow_run_id = workflow_run_id
|
||||
trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()
|
||||
|
||||
if trigger_log.elapsed_time is None:
|
||||
trigger_log.elapsed_time = elapsed_time
|
||||
else:
|
||||
trigger_log.elapsed_time += elapsed_time
|
||||
|
||||
trigger_log.total_tokens = total_tokens
|
||||
trigger_log.finished_at = datetime.now(UTC)
|
||||
repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
@@ -1,15 +1,13 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentTask:
|
||||
"""Document task entity for document indexing operations.
|
||||
|
||||
|
||||
This class represents a document indexing task that can be queued
|
||||
and processed by the document indexing system.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
document_ids: Sequence[str]
|
||||
document_ids: list[str]
|
||||
|
||||
@@ -14,6 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
@@ -271,7 +272,6 @@ class MCPProviderEntity(BaseModel):
|
||||
|
||||
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Generic method to decrypt dictionary fields"""
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
@@ -14,7 +14,6 @@ class CommonParameterType(StrEnum):
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
CHECKBOX = "checkbox"
|
||||
ANY = auto()
|
||||
|
||||
# Dynamic select parameter
|
||||
|
||||
@@ -1533,9 +1533,6 @@ class ProviderConfiguration(BaseModel):
|
||||
# Return composite sort key: (model_type value, model position index)
|
||||
return (model.model_type.value, position_index)
|
||||
|
||||
# Deduplicate
|
||||
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
|
||||
|
||||
# Sort using the composite sort key
|
||||
return sorted(provider_models, key=get_sort_key)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel):
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict | None
|
||||
credentials: dict | None = None
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
@@ -207,7 +207,6 @@ class ProviderConfig(BasicProviderConfig):
|
||||
required: bool = False
|
||||
default: Union[int, str, float, bool] | None = None
|
||||
options: list[Option] | None = None
|
||||
multiple: bool | None = False
|
||||
label: I18nObject | None = None
|
||||
help: I18nObject | None = None
|
||||
url: str | None = None
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
import contextlib
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> dict[str, Any] | None:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = dict(self._deep_copy(data))
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
mask credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = dict(self._deep_copy(data))
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def mask_plugin_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
return self.mask_credentials(data)
|
||||
|
||||
def decrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = dict(self._deep_copy(data))
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(dict(data))
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
@@ -56,6 +56,9 @@ class HostingConfiguration:
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
@@ -128,7 +131,7 @@ class HostingConfiguration:
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
@@ -156,18 +159,49 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_anthropic() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
def init_gemini(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_GEMINI_API_BASE:
|
||||
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_anthropic(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_quota = PaidHostingQuota()
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
@@ -185,6 +219,66 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_xai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_XAI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_XAI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_XAI_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_deepseek(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_minimax() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
|
||||
@@ -14,10 +14,12 @@ from core.llm_generator.prompts import (
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
from core.memory.entities import MemoryBlock, MemoryBlockSpec
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
@@ -28,6 +30,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
@@ -562,3 +565,35 @@ class LLMGenerator:
|
||||
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
|
||||
)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def update_memory_block(
|
||||
tenant_id: str,
|
||||
visible_history: Sequence[tuple[str, str]],
|
||||
variable_pool: VariablePool,
|
||||
memory_block: MemoryBlock,
|
||||
memory_spec: MemoryBlockSpec,
|
||||
) -> str:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=memory_spec.model.provider,
|
||||
model=memory_spec.model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
formatted_history = ""
|
||||
for sender, message in visible_history:
|
||||
formatted_history += f"{sender}: {message}\n"
|
||||
filled_instruction = variable_pool.convert_template(memory_spec.instruction).text
|
||||
formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
|
||||
inputs={
|
||||
"formatted_history": formatted_history,
|
||||
"current_value": memory_block.value,
|
||||
"instruction": filled_instruction,
|
||||
}
|
||||
)
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=[UserPromptMessage(content=formatted_prompt)],
|
||||
model_parameters=memory_spec.model.completion_params,
|
||||
stream=False,
|
||||
)
|
||||
return llm_result.message.get_text_content()
|
||||
|
||||
@@ -422,3 +422,18 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
||||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
MEMORY_UPDATE_PROMPT = """
|
||||
Based on the following conversation history, update the memory content:
|
||||
|
||||
Conversation history:
|
||||
{{formatted_history}}
|
||||
|
||||
Current memory:
|
||||
{{current_value}}
|
||||
|
||||
Update instruction:
|
||||
{{instruction}}
|
||||
|
||||
Please output only the updated memory content, no other text like greeting:
|
||||
"""
|
||||
|
||||
119
api/core/memory/entities.py
Normal file
119
api/core/memory/entities.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
|
||||
|
||||
class MemoryScope(StrEnum):
|
||||
"""Memory scope determined by node_id field"""
|
||||
APP = "app" # node_id is None
|
||||
NODE = "node" # node_id is not None
|
||||
|
||||
|
||||
class MemoryTerm(StrEnum):
|
||||
"""Memory term determined by conversation_id field"""
|
||||
SESSION = "session" # conversation_id is not None
|
||||
PERSISTENT = "persistent" # conversation_id is None
|
||||
|
||||
|
||||
class MemoryStrategy(StrEnum):
|
||||
ON_TURNS = "on_turns"
|
||||
|
||||
|
||||
class MemoryScheduleMode(StrEnum):
|
||||
SYNC = "sync"
|
||||
ASYNC = "async"
|
||||
|
||||
|
||||
class MemoryBlockSpec(BaseModel):
|
||||
"""Memory block specification for workflow configuration"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid4()),
|
||||
description="Unique identifier for the memory block",
|
||||
)
|
||||
name: str = Field(description="Display name of the memory block")
|
||||
description: str = Field(default="", description="Description of the memory block")
|
||||
template: str = Field(description="Initial template content for the memory")
|
||||
instruction: str = Field(description="Instructions for updating the memory")
|
||||
scope: MemoryScope = Field(description="Scope of the memory (app or node level)")
|
||||
term: MemoryTerm = Field(description="Term of the memory (session or persistent)")
|
||||
strategy: MemoryStrategy = Field(description="Update strategy for the memory")
|
||||
update_turns: int = Field(gt=0, description="Number of turns between updates")
|
||||
preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve")
|
||||
schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode")
|
||||
model: ModelConfig = Field(description="Model configuration for memory updates")
|
||||
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
|
||||
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
|
||||
|
||||
|
||||
class MemoryCreatedBy(BaseModel):
|
||||
end_user_id: str | None = None
|
||||
account_id: str | None = None
|
||||
|
||||
|
||||
class MemoryBlock(BaseModel):
|
||||
"""Runtime memory block instance
|
||||
|
||||
Design Rules:
|
||||
- app_id = None: Global memory (future feature, not implemented yet)
|
||||
- app_id = str: App-specific memory
|
||||
- conversation_id = None: Persistent memory (cross-conversation)
|
||||
- conversation_id = str: Session memory (conversation-specific)
|
||||
- node_id = None: App-level scope
|
||||
- node_id = str: Node-level scope
|
||||
|
||||
These rules implicitly determine scope and term without redundant storage.
|
||||
"""
|
||||
spec: MemoryBlockSpec
|
||||
tenant_id: str
|
||||
value: str
|
||||
app_id: str
|
||||
conversation_id: Optional[str] = None
|
||||
node_id: Optional[str] = None
|
||||
edited_by_user: bool = False
|
||||
created_by: MemoryCreatedBy
|
||||
version: int = Field(description="Memory block version number")
|
||||
|
||||
|
||||
class MemoryValueData(BaseModel):
|
||||
value: str
|
||||
edited_by_user: bool = False
|
||||
|
||||
|
||||
class ChatflowConversationMetadata(BaseModel):
|
||||
"""Metadata for chatflow conversation with visible message count"""
|
||||
type: str = "mutable_visible_window"
|
||||
visible_count: int = Field(gt=0, description="Number of visible messages to keep")
|
||||
|
||||
|
||||
class MemoryBlockWithConversation(MemoryBlock):
|
||||
"""MemoryBlock with optional conversation metadata for session memories"""
|
||||
conversation_metadata: ChatflowConversationMetadata = Field(
|
||||
description="Conversation metadata, only present for session memories"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_memory_block(
|
||||
cls,
|
||||
memory_block: MemoryBlock,
|
||||
conversation_metadata: ChatflowConversationMetadata
|
||||
) -> MemoryBlockWithConversation:
|
||||
"""Create MemoryBlockWithConversation from MemoryBlock"""
|
||||
return cls(
|
||||
spec=memory_block.spec,
|
||||
tenant_id=memory_block.tenant_id,
|
||||
value=memory_block.value,
|
||||
app_id=memory_block.app_id,
|
||||
conversation_id=memory_block.conversation_id,
|
||||
node_id=memory_block.node_id,
|
||||
edited_by_user=memory_block.edited_by_user,
|
||||
created_by=memory_block.created_by,
|
||||
version=memory_block.version,
|
||||
conversation_metadata=conversation_metadata
|
||||
)
|
||||
6
api/core/memory/errors.py
Normal file
6
api/core/memory/errors.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class MemorySyncTimeoutError(Exception):
|
||||
def __init__(self, app_id: str, conversation_id: str):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.message = "Memory synchronization timeout after 50 seconds"
|
||||
super().__init__(self.message)
|
||||
@@ -1,22 +1,21 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
@@ -31,10 +30,9 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -101,45 +99,22 @@ def datetime_to_nanos(dt: datetime | None) -> int:
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def error_to_string(error: Exception | str | None) -> str:
|
||||
"""Convert an error to a string with traceback information."""
|
||||
error_message = "Empty Stack Trace"
|
||||
if error:
|
||||
if isinstance(error, Exception):
|
||||
string_stacktrace = "".join(traceback.format_exception(error))
|
||||
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
|
||||
else:
|
||||
error_message = str(error)
|
||||
return error_message
|
||||
def string_to_trace_id128(string: str | None) -> int:
|
||||
"""
|
||||
Convert any input string into a stable 128-bit integer trace ID.
|
||||
|
||||
This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
|
||||
It's suitable for generating consistent, unique identifiers from strings.
|
||||
"""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
"""Set the status of the current span based on the presence of an error."""
|
||||
if error:
|
||||
error_string = error_to_string(error)
|
||||
current_span.set_status(Status(StatusCode.ERROR, error_string))
|
||||
# Take the first 16 bytes (128 bits) of the hash digest
|
||||
digest = hash_object.digest()[:16]
|
||||
|
||||
if isinstance(error, Exception):
|
||||
current_span.record_exception(error)
|
||||
else:
|
||||
exception_type = error.__class__.__name__
|
||||
exception_message = str(error)
|
||||
if not exception_message:
|
||||
exception_message = repr(error)
|
||||
attributes: dict[str, AttributeValue] = {
|
||||
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
|
||||
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
|
||||
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
|
||||
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
|
||||
}
|
||||
current_span.add_event(name="exception", attributes=attributes)
|
||||
else:
|
||||
current_span.set_status(Status(StatusCode.OK))
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any) -> str:
|
||||
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
|
||||
return json.dumps(obj, default=str, ensure_ascii=False)
|
||||
# Convert to a 128-bit integer
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
@@ -156,12 +131,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.propagator = TraceContextTextMapPropagator()
|
||||
self.dify_trace_ids: set[str] = set()
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
|
||||
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
|
||||
logger.info("[Arize/Phoenix] Trace: %s", trace_info)
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
@@ -179,7 +151,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
|
||||
logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
@@ -194,9 +166,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
@@ -208,58 +186,31 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
# Through workflow_run_id, get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
for node_execution in workflow_node_executions:
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
inputs_value = node_execution.inputs or {}
|
||||
outputs_value = node_execution.outputs or {}
|
||||
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
|
||||
node_metadata.update(
|
||||
{
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
|
||||
}
|
||||
)
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
@@ -272,9 +223,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
|
||||
)
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
@@ -286,20 +236,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
workflow_span_context = set_span_in_context(workflow_span)
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
context=workflow_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -313,8 +260,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
|
||||
if model:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
|
||||
outputs = (
|
||||
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
|
||||
)
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
|
||||
@@ -325,16 +275,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
|
||||
node_span.set_attributes(llm_attributes)
|
||||
finally:
|
||||
if node_execution.status == "failed":
|
||||
set_span_status(node_span, node_execution.error)
|
||||
else:
|
||||
set_span_status(node_span)
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(workflow_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(workflow_span)
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
@@ -380,18 +322,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=message_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
message_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert outputs to string based on type
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
@@ -425,26 +383,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
message_span_context = set_span_in_context(message_span)
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=message_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(llm_span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(llm_span)
|
||||
if trace_info.error:
|
||||
llm_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(message_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(message_span)
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
@@ -460,9 +418,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
@@ -481,14 +445,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
@@ -511,9 +480,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
@@ -524,14 +499,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=root_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
@@ -553,9 +533,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
@@ -568,14 +554,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=root_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
@@ -589,9 +580,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
|
||||
|
||||
# Create span context with the same trace_id as the parent
|
||||
# todo: Create with the appropriate parent span context, so that the tool span is
|
||||
# a child of the appropriate span (e.g. message span)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=tool_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
@@ -610,14 +612,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
@@ -634,9 +641,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
@@ -650,34 +663,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def ensure_root_span(self, dify_trace_id: str | None):
|
||||
"""Ensure a unique root span exists for the given Dify trace ID."""
|
||||
if str(dify_trace_id) not in self.dify_trace_ids:
|
||||
self.carrier: dict[str, str] = {}
|
||||
|
||||
root_span = self.tracer.start_span(name="Dify")
|
||||
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
|
||||
root_span.set_attribute("dify_project_name", str(self.project))
|
||||
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
|
||||
|
||||
with use_span(root_span, end_on_exit=False):
|
||||
self.propagator.inject(carrier=self.carrier)
|
||||
|
||||
set_span_status(root_span)
|
||||
root_span.end()
|
||||
self.dify_trace_ids.add(str(dify_trace_id))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
@@ -697,6 +698,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = db.session.scalars(
|
||||
select(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
).all()
|
||||
return workflow_nodes
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
"""Helper method to construct LLM attributes with passed prompts."""
|
||||
attributes = {}
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Union
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
@@ -15,7 +16,6 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@@ -64,7 +64,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
app = cls._get_app(app_id, tenant_id)
|
||||
if not user_id:
|
||||
user = EndUserService.get_or_create_end_user(app)
|
||||
user = create_or_update_end_user_for_user_id(app)
|
||||
else:
|
||||
user = cls._get_user(user_id)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class PluginParameterType(StrEnum):
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
|
||||
CHECKBOX = CommonParameterType.CHECKBOX
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
@@ -94,7 +94,6 @@ def as_normal_type(typ: StrEnum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
PluginParameterType.CHECKBOX,
|
||||
}:
|
||||
return "string"
|
||||
return typ.value
|
||||
@@ -103,13 +102,7 @@ def as_normal_type(typ: StrEnum):
|
||||
def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case (
|
||||
PluginParameterType.STRING
|
||||
| PluginParameterType.SECRET_INPUT
|
||||
| PluginParameterType.SELECT
|
||||
| PluginParameterType.CHECKBOX
|
||||
| PluginParameterType.DYNAMIC_SELECT
|
||||
):
|
||||
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,6 @@ from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(StrEnum):
|
||||
@@ -64,7 +63,6 @@ class PluginCategory(StrEnum):
|
||||
Extension = auto()
|
||||
AgentStrategy = "agent-strategy"
|
||||
Datasource = "datasource"
|
||||
Trigger = "trigger"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
@@ -73,7 +71,6 @@ class PluginDeclaration(BaseModel):
|
||||
models: list[str] | None = Field(default_factory=list[str])
|
||||
endpoints: list[str] | None = Field(default_factory=list[str])
|
||||
datasources: list[str] | None = Field(default_factory=list[str])
|
||||
triggers: list[str] | None = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: str | None = Field(default=None)
|
||||
@@ -109,7 +106,6 @@ class PluginDeclaration(BaseModel):
|
||||
endpoint: EndpointProviderDeclaration | None = None
|
||||
agent_strategy: AgentStrategyProviderEntity | None = None
|
||||
datasource: DatasourceProviderEntity | None = None
|
||||
trigger: TriggerProviderEntity | None = None
|
||||
meta: Meta
|
||||
|
||||
@field_validator("version")
|
||||
@@ -133,8 +129,6 @@ class PluginDeclaration(BaseModel):
|
||||
values["category"] = PluginCategory.Datasource
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
elif values.get("trigger"):
|
||||
values["category"] = PluginCategory.Trigger
|
||||
else:
|
||||
values["category"] = PluginCategory.Extension
|
||||
return values
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
@@ -15,7 +14,6 @@ from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
@@ -207,53 +205,3 @@ class PluginListResponse(BaseModel):
|
||||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
||||
|
||||
class PluginTriggerProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: TriggerProviderEntity
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
UNAUTHORIZED = "unauthorized"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
elif self == CredentialType.UNAUTHORIZED:
|
||||
return "UNAUTHORIZED"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
elif type_name == "unauthorized":
|
||||
return cls.UNAUTHORIZED
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
|
||||
class PluginReadmeResponse(BaseModel):
|
||||
content: str = Field(description="The readme of the plugin.")
|
||||
language: str = Field(description="The language of the readme.")
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import binascii
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import Response
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
@@ -17,7 +13,6 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.utils.http_parser import deserialize_response
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
@@ -242,43 +237,3 @@ class RequestFetchAppInfo(BaseModel):
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
|
||||
|
||||
class TriggerInvokeEventResponse(BaseModel):
|
||||
variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
cancelled: bool = Field(default=False)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("variables", mode="before")
|
||||
@classmethod
|
||||
def convert_variables(cls, v):
|
||||
if isinstance(v, str):
|
||||
return json.loads(v)
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
class TriggerSubscriptionResponse(BaseModel):
|
||||
subscription: dict[str, Any]
|
||||
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
result: bool
|
||||
|
||||
|
||||
class TriggerDispatchResponse(BaseModel):
|
||||
user_id: str
|
||||
events: list[str]
|
||||
response: Response
|
||||
payload: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("response", mode="before")
|
||||
@classmethod
|
||||
def convert_response(cls, v: str):
|
||||
try:
|
||||
return deserialize_response(binascii.unhexlify(v.encode()))
|
||||
except Exception as e:
|
||||
raise ValueError("Failed to deserialize response from hex string") from e
|
||||
|
||||
@@ -10,13 +10,3 @@ class PluginAssetManager(BasePluginClient):
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"can not found asset {id}")
|
||||
return response.content
|
||||
|
||||
def extract_asset(self, tenant_id: str, plugin_unique_identifier: str, filename: str) -> bytes:
|
||||
response = self._request(
|
||||
method="GET",
|
||||
path=f"plugin/{tenant_id}/extract-asset/",
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier, "file_path": filename},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"can not found asset {plugin_unique_identifier}, {str(response.status_code)}")
|
||||
return response.content
|
||||
|
||||
@@ -29,12 +29,6 @@ from core.plugin.impl.exc import (
|
||||
PluginPermissionDeniedError,
|
||||
PluginUniqueIdentifierError,
|
||||
)
|
||||
from core.trigger.errors import (
|
||||
EventIgnoreError,
|
||||
TriggerInvokeError,
|
||||
TriggerPluginInvokeError,
|
||||
TriggerProviderCredentialValidationError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
_plugin_daemon_timeout_config = cast(
|
||||
@@ -49,7 +43,7 @@ elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
||||
else:
|
||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,10 +53,10 @@ class BasePluginClient:
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | str | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
@@ -93,17 +87,17 @@ class BasePluginClient:
|
||||
def _prepare_request(
|
||||
self,
|
||||
path: str,
|
||||
headers: dict[str, str] | None,
|
||||
data: bytes | dict[str, Any] | str | None,
|
||||
params: dict[str, Any] | None,
|
||||
files: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
|
||||
headers: dict | None,
|
||||
data: bytes | dict | str | None,
|
||||
params: dict | None,
|
||||
files: dict | None,
|
||||
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
prepared_headers = dict(headers or {})
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
prepared_data: bytes | dict[str, Any] | str | None = (
|
||||
prepared_data: bytes | dict | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
@@ -118,10 +112,10 @@ class BasePluginClient:
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
params: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API
|
||||
@@ -144,7 +138,7 @@ class BasePluginClient:
|
||||
try:
|
||||
with httpx.stream(**stream_kwargs) as response:
|
||||
for raw_line in response.iter_lines():
|
||||
if not raw_line:
|
||||
if raw_line is None:
|
||||
continue
|
||||
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
|
||||
line = line.strip()
|
||||
@@ -161,10 +155,10 @@ class BasePluginClient:
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
@@ -177,10 +171,10 @@ class BasePluginClient:
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
@@ -193,11 +187,11 @@ class BasePluginClient:
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
transformer: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
transformer: Callable[[dict], dict] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
@@ -245,10 +239,10 @@ class BasePluginClient:
|
||||
method: str,
|
||||
path: str,
|
||||
type_: type[T],
|
||||
headers: dict[str, str] | None = None,
|
||||
data: bytes | dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
@@ -308,14 +302,6 @@ class BasePluginClient:
|
||||
raise CredentialsValidateFailedError(error_object.get("message"))
|
||||
case EndpointSetupFailedError.__name__:
|
||||
raise EndpointSetupFailedError(error_object.get("message"))
|
||||
case TriggerProviderCredentialValidationError.__name__:
|
||||
raise TriggerProviderCredentialValidationError(error_object.get("message"))
|
||||
case TriggerPluginInvokeError.__name__:
|
||||
raise TriggerPluginInvokeError(description=error_object.get("description"))
|
||||
case TriggerInvokeError.__name__:
|
||||
raise TriggerInvokeError(error_object.get("message"))
|
||||
case EventIgnoreError.__name__:
|
||||
raise EventIgnoreError(description=error_object.get("description"))
|
||||
case _:
|
||||
raise PluginInvokeError(description=message)
|
||||
case PluginDaemonInternalServerError.__name__:
|
||||
|
||||
@@ -15,7 +15,6 @@ class DynamicSelectClient(BasePluginClient):
|
||||
provider: str,
|
||||
action: str,
|
||||
credentials: Mapping[str, Any],
|
||||
credential_type: str,
|
||||
parameter: str,
|
||||
) -> PluginDynamicSelectOptionsResponse:
|
||||
"""
|
||||
@@ -30,7 +29,6 @@ class DynamicSelectClient(BasePluginClient):
|
||||
"data": {
|
||||
"provider": GenericProviderID(provider).provider_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"provider_action": action,
|
||||
"parameter": parameter,
|
||||
},
|
||||
|
||||
@@ -58,20 +58,6 @@ class PluginInvokeError(PluginDaemonClientSideError, ValueError):
|
||||
except Exception:
|
||||
return self.description
|
||||
|
||||
def to_user_friendly_error(self, plugin_name: str = "currently running plugin") -> str:
|
||||
"""
|
||||
Convert the error to a user-friendly error message.
|
||||
|
||||
:param plugin_name: The name of the plugin that caused the error.
|
||||
:return: A user-friendly error message.
|
||||
"""
|
||||
return (
|
||||
f"An error occurred in the {plugin_name}, "
|
||||
f"please contact the author of {plugin_name} for help, "
|
||||
f"error type: {self.get_error_type()}, "
|
||||
f"error details: {self.get_error_message()}"
|
||||
)
|
||||
|
||||
|
||||
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
|
||||
description: str = "Unique Identifier Error"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
MissingPluginDependency,
|
||||
@@ -15,35 +13,12 @@ from core.plugin.entities.plugin_daemon import (
|
||||
PluginInstallTask,
|
||||
PluginInstallTaskStartResponse,
|
||||
PluginListResponse,
|
||||
PluginReadmeResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginInstaller(BasePluginClient):
|
||||
def fetch_plugin_readme(self, tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
|
||||
"""
|
||||
Fetch plugin readme
|
||||
"""
|
||||
try:
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/fetch/readme",
|
||||
PluginReadmeResponse,
|
||||
params={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"language": language,
|
||||
},
|
||||
)
|
||||
return response.content
|
||||
except HTTPError as e:
|
||||
message = e.args[0]
|
||||
if "404" in message:
|
||||
return ""
|
||||
raise e
|
||||
|
||||
def fetch_plugin_by_identifier(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@@ -3,12 +3,14 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginToolProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
from models.provider_ids import GenericProviderID, ToolProviderID
|
||||
|
||||
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Request
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
|
||||
from core.plugin.entities.request import (
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeEventResponse,
|
||||
TriggerSubscriptionResponse,
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.http_parser import serialize_request
|
||||
from core.trigger.entities.entities import Subscription
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
||||
|
||||
class PluginTriggerClient(BasePluginClient):
|
||||
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
|
||||
"""
|
||||
Fetch trigger providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
|
||||
for event in declaration.get("events", []):
|
||||
event["identity"]["provider"] = provider_id
|
||||
|
||||
return json_response
|
||||
|
||||
response: list[PluginTriggerProviderEntity] = self._request_with_plugin_daemon_response(
|
||||
method="GET",
|
||||
path=f"plugin/{tenant_id}/management/triggers",
|
||||
type_=list[PluginTriggerProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for event in provider.declaration.events:
|
||||
event.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
|
||||
"""
|
||||
Fetch trigger provider for the given tenant and plugin.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for event in data.get("declaration", {}).get("events", []):
|
||||
event["identity"]["provider"] = str(provider_id)
|
||||
|
||||
return json_response
|
||||
|
||||
response: PluginTriggerProviderEntity = self._request_with_plugin_daemon_response(
|
||||
method="GET",
|
||||
path=f"plugin/{tenant_id}/management/trigger",
|
||||
type_=PluginTriggerProviderEntity,
|
||||
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = str(provider_id)
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for event in response.declaration.events:
|
||||
event.identity.provider = str(provider_id)
|
||||
|
||||
return response
|
||||
|
||||
def invoke_trigger_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
event_name: str,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
parameters: Mapping[str, Any],
|
||||
subscription: Subscription,
|
||||
payload: Mapping[str, Any],
|
||||
) -> TriggerInvokeEventResponse:
|
||||
"""
|
||||
Invoke a trigger with the given parameters.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerInvokeEventResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/invoke_event",
|
||||
type_=TriggerInvokeEventResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"event": event_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"subscription": subscription.model_dump(),
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
"parameters": parameters,
|
||||
"payload": payload,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for invoke trigger")
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the credentials of the trigger provider.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerValidateProviderCredentialsResponse, None, None] = (
|
||||
self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
|
||||
type_=TriggerValidateProviderCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("No response received from plugin daemon for validate provider credentials")
|
||||
|
||||
def dispatch_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
subscription: Mapping[str, Any],
|
||||
request: Request,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch an event to triggers.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
|
||||
type_=TriggerDispatchResponse,
|
||||
data={
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"subscription": subscription,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for dispatch event")
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
endpoint: str,
|
||||
parameters: Mapping[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Subscribe to a trigger.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/subscribe",
|
||||
type_=TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"endpoint": endpoint,
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for subscribe")
|
||||
|
||||
def unsubscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Unsubscribe from a trigger.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
|
||||
type_=TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for unsubscribe")
|
||||
|
||||
def refresh(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Refresh a trigger subscription.
|
||||
"""
|
||||
provider_id = TriggerProviderID(provider)
|
||||
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/trigger/refresh",
|
||||
type_=TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for refresh")
|
||||
@@ -1,163 +0,0 @@
|
||||
from io import BytesIO
|
||||
|
||||
from flask import Request, Response
|
||||
from werkzeug.datastructures import Headers
|
||||
|
||||
|
||||
def serialize_request(request: Request) -> bytes:
|
||||
method = request.method
|
||||
path = request.full_path.rstrip("?")
|
||||
raw = f"{method} {path} HTTP/1.1\r\n".encode()
|
||||
|
||||
for name, value in request.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
raw += b"\r\n"
|
||||
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_request(raw_data: bytes) -> Request:
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP request")
|
||||
|
||||
request_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = request_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid request line: {request_line}")
|
||||
|
||||
method = parts[0]
|
||||
full_path = parts[1]
|
||||
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
|
||||
|
||||
if "?" in full_path:
|
||||
path, query_string = full_path.split("?", 1)
|
||||
else:
|
||||
path = full_path
|
||||
query_string = ""
|
||||
|
||||
headers = Headers()
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
headers.add(name, value.strip())
|
||||
|
||||
host = headers.get("Host", "localhost")
|
||||
if ":" in host:
|
||||
server_name, server_port = host.rsplit(":", 1)
|
||||
else:
|
||||
server_name = host
|
||||
server_port = "80"
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": method,
|
||||
"PATH_INFO": path,
|
||||
"QUERY_STRING": query_string,
|
||||
"SERVER_NAME": server_name,
|
||||
"SERVER_PORT": server_port,
|
||||
"SERVER_PROTOCOL": protocol,
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
}
|
||||
|
||||
if "Content-Type" in headers:
|
||||
content_type = headers.get("Content-Type")
|
||||
if content_type is not None:
|
||||
environ["CONTENT_TYPE"] = content_type
|
||||
|
||||
if "Content-Length" in headers:
|
||||
content_length = headers.get("Content-Length")
|
||||
if content_length is not None:
|
||||
environ["CONTENT_LENGTH"] = content_length
|
||||
elif body:
|
||||
environ["CONTENT_LENGTH"] = str(len(body))
|
||||
|
||||
for name, value in headers.items():
|
||||
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
|
||||
continue
|
||||
env_name = f"HTTP_{name.upper().replace('-', '_')}"
|
||||
environ[env_name] = value
|
||||
|
||||
return Request(environ)
|
||||
|
||||
|
||||
def serialize_response(response: Response) -> bytes:
|
||||
raw = f"HTTP/1.1 {response.status}\r\n".encode()
|
||||
|
||||
for name, value in response.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
raw += b"\r\n"
|
||||
|
||||
body = response.get_data(as_text=False)
|
||||
if body:
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_response(raw_data: bytes) -> Response:
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP response")
|
||||
|
||||
status_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = status_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid status line: {status_line}")
|
||||
|
||||
status_code = int(parts[1])
|
||||
|
||||
response = Response(response=body, status=status_code)
|
||||
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
response.headers[name] = value.strip()
|
||||
|
||||
return response
|
||||
@@ -45,6 +45,12 @@ class MemoryConfig(BaseModel):
|
||||
enabled: bool
|
||||
size: int | None = None
|
||||
|
||||
mode: Literal["linear", "block"] | None = "linear"
|
||||
block_id: list[str] | None = None
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: WindowConfig
|
||||
query_prompt_template: str | None = None
|
||||
|
||||
@property
|
||||
def is_block_mode(self) -> bool:
|
||||
return self.mode == "block" and bool(self.block_id)
|
||||
|
||||
@@ -618,18 +618,18 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type == ProviderQuotaType.TRIAL:
|
||||
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
|
||||
# Init trial provider records if not exists
|
||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||
if quota.quota_type not in provider_quota_to_provider_record_dict:
|
||||
try:
|
||||
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
|
||||
new_provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=quota.quota_type,
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
)
|
||||
@@ -641,8 +641,8 @@ class ProviderManager:
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == quota.quota_type,
|
||||
)
|
||||
existed_provider_record = db.session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
@@ -652,7 +652,7 @@ class ProviderManager:
|
||||
existed_provider_record.is_valid = True
|
||||
db.session.commit()
|
||||
|
||||
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
|
||||
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
|
||||
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@@ -912,6 +912,22 @@ class ProviderManager:
|
||||
provider_record
|
||||
)
|
||||
quota_configurations = []
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
paid_pool = None
|
||||
|
||||
for provider_quota in provider_hosting_configuration.quotas:
|
||||
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
|
||||
if provider_quota.quota_type == ProviderQuotaType.FREE:
|
||||
@@ -932,16 +948,36 @@ class ProviderManager:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
else:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configurations.append(quota_configuration)
|
||||
|
||||
|
||||
@@ -147,8 +147,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
# remove any suffix like "-SNAPSHOT" from the version string
|
||||
return cast(str, info["version"]["number"]).split("-")[0]
|
||||
return cast(str, info["version"]["number"])
|
||||
|
||||
def _check_version(self):
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
|
||||
@@ -289,7 +289,8 @@ class OracleVector(BaseVector):
|
||||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||
# nr: person name, ns: place name, nt: organization name
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
|
||||
@@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# Vastbase 支持的向量维度取值范围为 [1,16000]
|
||||
# Vastbase supports vector dimensions in range [1, 16000]
|
||||
if dimension <= 16000:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@@ -92,7 +92,7 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# Urls without scheme won't be parsed correctly in some python verions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
|
||||
@@ -1,79 +1,89 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
|
||||
TASK_WRAPPER_PREFIX = "__WRAPPER__:"
|
||||
|
||||
|
||||
class TaskWrapper(BaseModel):
|
||||
@dataclass
|
||||
class TaskWrapper:
|
||||
data: Any
|
||||
|
||||
|
||||
def serialize(self) -> str:
|
||||
return self.model_dump_json()
|
||||
|
||||
return json.dumps(self.data, ensure_ascii=False)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
|
||||
return cls.model_validate_json(serialized_data)
|
||||
def deserialize(cls, serialized_data: str) -> 'TaskWrapper':
|
||||
data = json.loads(serialized_data)
|
||||
return cls(data)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueue:
|
||||
class TenantSelfTaskQueue:
|
||||
"""
|
||||
Simple queue for tenant isolated tasks, used for rag related tenant tasks isolation.
|
||||
Simple queue for tenant self tasks, used for tenant self task isolation.
|
||||
It uses Redis list to store tasks, and Redis key to store task waiting flag.
|
||||
Support tasks that can be serialized by json.
|
||||
"""
|
||||
DEFAULT_TASK_TTL = 60 * 60
|
||||
|
||||
def __init__(self, tenant_id: str, unique_key: str):
|
||||
self._tenant_id = tenant_id
|
||||
self._unique_key = unique_key
|
||||
self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
|
||||
self._task_key = f"tenant_{unique_key}_task:{tenant_id}"
|
||||
self.tenant_id = tenant_id
|
||||
self.unique_key = unique_key
|
||||
self.queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
|
||||
self.task_key = f"tenant_{unique_key}_task:{tenant_id}"
|
||||
|
||||
def get_task_key(self):
|
||||
return redis_client.get(self._task_key)
|
||||
return redis_client.get(self.task_key)
|
||||
|
||||
def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
|
||||
redis_client.setex(self._task_key, ttl, 1)
|
||||
def set_task_waiting_time(self, ttl: int | None = None):
|
||||
ttl = ttl or self.DEFAULT_TASK_TTL
|
||||
redis_client.setex(self.task_key, ttl, 1)
|
||||
|
||||
def delete_task_key(self):
|
||||
redis_client.delete(self._task_key)
|
||||
redis_client.delete(self.task_key)
|
||||
|
||||
def push_tasks(self, tasks: Sequence[Any]):
|
||||
def push_tasks(self, tasks: list):
|
||||
serialized_tasks = []
|
||||
for task in tasks:
|
||||
# Store str list directly, maintaining full compatibility for pipeline scenarios
|
||||
if isinstance(task, str):
|
||||
serialized_tasks.append(task)
|
||||
else:
|
||||
# Use TaskWrapper to do JSON serialization for non-string tasks
|
||||
wrapper = TaskWrapper(data=task)
|
||||
# Use TaskWrapper to do JSON serialization, add prefix for identification
|
||||
wrapper = TaskWrapper(task)
|
||||
serialized_data = wrapper.serialize()
|
||||
serialized_tasks.append(serialized_data)
|
||||
|
||||
redis_client.lpush(self._queue, *serialized_tasks)
|
||||
|
||||
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
|
||||
serialized_tasks.append(f"{TASK_WRAPPER_PREFIX}{serialized_data}")
|
||||
|
||||
redis_client.lpush(self.queue, *serialized_tasks)
|
||||
|
||||
def pull_tasks(self, count: int = 1) -> list:
|
||||
if count <= 0:
|
||||
return []
|
||||
|
||||
|
||||
tasks = []
|
||||
for _ in range(count):
|
||||
serialized_task = redis_client.rpop(self._queue)
|
||||
serialized_task = redis_client.rpop(self.queue)
|
||||
if not serialized_task:
|
||||
break
|
||||
|
||||
|
||||
if isinstance(serialized_task, bytes):
|
||||
serialized_task = serialized_task.decode("utf-8")
|
||||
|
||||
try:
|
||||
wrapper = TaskWrapper.deserialize(serialized_task)
|
||||
tasks.append(wrapper.data)
|
||||
except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
|
||||
# Fall back to raw string for legacy format or invalid JSON
|
||||
serialized_task = serialized_task.decode('utf-8')
|
||||
|
||||
# Check if use TaskWrapper or not
|
||||
if serialized_task.startswith(TASK_WRAPPER_PREFIX):
|
||||
try:
|
||||
wrapper_data = serialized_task[len(TASK_WRAPPER_PREFIX):]
|
||||
wrapper = TaskWrapper.deserialize(wrapper_data)
|
||||
tasks.append(wrapper.data)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
tasks.append(serialized_task)
|
||||
else:
|
||||
tasks.append(serialized_task)
|
||||
|
||||
|
||||
return tasks
|
||||
|
||||
def get_next_task(self):
|
||||
tasks = self.pull_tasks(1)
|
||||
return tasks[0] if tasks else None
|
||||
|
||||
@@ -3,8 +3,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import (
|
||||
CredentialType,
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderEntity,
|
||||
|
||||
@@ -6,10 +6,9 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
|
||||
|
||||
|
||||
class ToolApiEntity(BaseModel):
|
||||
|
||||
@@ -268,7 +268,6 @@ class ToolParameter(PluginParameter):
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT
|
||||
FILE = PluginParameterType.FILE
|
||||
FILES = PluginParameterType.FILES
|
||||
CHECKBOX = PluginParameterType.CHECKBOX
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
|
||||
ANY = PluginParameterType.ANY
|
||||
@@ -490,3 +489,36 @@ class ToolSelector(BaseModel):
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class CredentialType(StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = auto()
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
@@ -55,11 +52,6 @@ class MCPTool(Tool):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
@@ -105,10 +97,6 @@ class MCPTool(Tool):
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
||||
@@ -8,6 +8,7 @@ from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
@@ -38,7 +39,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
@@ -49,6 +49,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
CredentialType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
@@ -288,8 +289,10 @@ class ToolManager:
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
builtin_provider.encrypted_credentials = (
|
||||
TypeAdapter(dict[str, Any])
|
||||
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
.decode("utf-8")
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
@@ -319,7 +322,7 @@ class ToolManager:
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
credentials=encrypter.decrypt(credentials),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
@@ -830,7 +833,7 @@ class ToolManager:
|
||||
controller=controller,
|
||||
)
|
||||
|
||||
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
||||
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
|
||||
@@ -1,24 +1,137 @@
|
||||
# Import generic components from provider_encryption module
|
||||
from core.helper.provider_encryption import (
|
||||
ProviderConfigCache,
|
||||
ProviderConfigEncrypter,
|
||||
create_provider_encrypter,
|
||||
)
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = [
|
||||
"ProviderConfigCache",
|
||||
"ProviderConfigEncrypter",
|
||||
"create_provider_encrypter",
|
||||
"create_tool_provider_encrypter",
|
||||
]
|
||||
|
||||
# Tool-specific imports
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> dict | None:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]):
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self):
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(
|
||||
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(
|
||||
tenant_id: str, controller: ToolProviderController
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Core trigger module initialization
|
||||
@@ -1,124 +0,0 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import TypeVar
|
||||
|
||||
from redis import RedisError
|
||||
|
||||
from core.trigger.debug.events import BaseDebugEvent
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TRIGGER_DEBUG_EVENT_TTL = 300
|
||||
|
||||
TTriggerDebugEvent = TypeVar("TTriggerDebugEvent", bound="BaseDebugEvent")
|
||||
|
||||
|
||||
class TriggerDebugEventBus:
|
||||
"""
|
||||
Unified Redis-based trigger debug service with polling support.
|
||||
|
||||
Uses {tenant_id} hash tags for Redis Cluster compatibility.
|
||||
Supports multiple event types through a generic dispatch/poll interface.
|
||||
"""
|
||||
|
||||
# LUA_SELECT: Atomic poll or register for event
|
||||
# KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}
|
||||
# KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:...
|
||||
# ARGV[1] = address_id
|
||||
LUA_SELECT = (
|
||||
"local v=redis.call('GET',KEYS[1]);"
|
||||
"if v then redis.call('DEL',KEYS[1]);return v end;"
|
||||
"redis.call('SADD',KEYS[2],ARGV[1]);"
|
||||
f"redis.call('EXPIRE',KEYS[2],{TRIGGER_DEBUG_EVENT_TTL});"
|
||||
"return false"
|
||||
)
|
||||
|
||||
# LUA_DISPATCH: Dispatch event to all waiting addresses
|
||||
# KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:...
|
||||
# ARGV[1] = tenant_id
|
||||
# ARGV[2] = event_json
|
||||
LUA_DISPATCH = (
|
||||
"local a=redis.call('SMEMBERS',KEYS[1]);"
|
||||
"if #a==0 then return 0 end;"
|
||||
"redis.call('DEL',KEYS[1]);"
|
||||
"for i=1,#a do "
|
||||
f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
|
||||
"end;"
|
||||
"return #a"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dispatch(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
event: BaseDebugEvent,
|
||||
pool_key: str,
|
||||
) -> int:
|
||||
"""
|
||||
Dispatch event to all waiting addresses in the pool.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for hash tag
|
||||
event: Event object to dispatch
|
||||
pool_key: Pool key (generate using build_{?}_pool_key(...))
|
||||
|
||||
Returns:
|
||||
Number of addresses the event was dispatched to
|
||||
"""
|
||||
event_data = event.model_dump_json()
|
||||
try:
|
||||
result = redis_client.eval(
|
||||
cls.LUA_DISPATCH,
|
||||
1,
|
||||
pool_key,
|
||||
tenant_id,
|
||||
event_data,
|
||||
)
|
||||
return int(result)
|
||||
except RedisError:
|
||||
logger.exception("Failed to dispatch event to pool: %s", pool_key)
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def poll(
|
||||
cls,
|
||||
event_type: type[TTriggerDebugEvent],
|
||||
pool_key: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
) -> TTriggerDebugEvent | None:
|
||||
"""
|
||||
Poll for an event or register to the waiting pool.
|
||||
|
||||
If an event is available in the inbox, return it immediately.
|
||||
Otherwise, register the address to the waiting pool for future dispatch.
|
||||
|
||||
Args:
|
||||
event_class: Event class for deserialization and type safety
|
||||
pool_key: Pool key (generate using build_{?}_pool_key(...))
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID for address calculation
|
||||
app_id: App ID for address calculation
|
||||
node_id: Node ID for address calculation
|
||||
|
||||
Returns:
|
||||
Event object if available, None otherwise
|
||||
"""
|
||||
address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
|
||||
address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}"
|
||||
|
||||
try:
|
||||
event_data = redis_client.eval(
|
||||
cls.LUA_SELECT,
|
||||
2,
|
||||
address,
|
||||
pool_key,
|
||||
address_id,
|
||||
)
|
||||
return event_type.model_validate_json(json_data=event_data) if event_data else None
|
||||
except RedisError:
|
||||
logger.exception("Failed to poll event from pool: %s", pool_key)
|
||||
return None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user