mirror of
https://github.com/langgenius/dify.git
synced 2026-04-12 15:39:25 +08:00
Compare commits
89 Commits
feat/email
...
test/perfo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d34c95bf8e | ||
|
|
1df1ffa2ec | ||
|
|
947dbd8854 | ||
|
|
ce619287b3 | ||
|
|
cd1ec65286 | ||
|
|
bdb9f29948 | ||
|
|
66cc1b4308 | ||
|
|
d52fb18457 | ||
|
|
4a2169bd5f | ||
|
|
2c9ee54a16 | ||
|
|
aef67ed7ec | ||
|
|
ddfd8c8525 | ||
|
|
2c1ab4879f | ||
|
|
229b4d621e | ||
|
|
0dee41c074 | ||
|
|
bf542233a9 | ||
|
|
38106074b4 | ||
|
|
1f4b3591ae | ||
|
|
7bf3d2c8bf | ||
|
|
da53bf511f | ||
|
|
7388fd1ec6 | ||
|
|
b803eeb528 | ||
|
|
14f79ee652 | ||
|
|
df89629e04 | ||
|
|
d427088ab5 | ||
|
|
32c541a9ed | ||
|
|
7e666dc3b1 | ||
|
|
5247c19498 | ||
|
|
9823edd3a2 | ||
|
|
88537991d6 | ||
|
|
8e910d8c59 | ||
|
|
a0b32b6027 | ||
|
|
bf7b2c339b | ||
|
|
a1dfe6d402 | ||
|
|
d2a3e8b9b1 | ||
|
|
ebb88bbe0b | ||
|
|
b690a9d839 | ||
|
|
9d9423808e | ||
|
|
3e96c0c468 | ||
|
|
6eb155ae69 | ||
|
|
b27c540379 | ||
|
|
8b1f428ead | ||
|
|
1d54ffcf89 | ||
|
|
d9eb5554b3 | ||
|
|
da94bdeb54 | ||
|
|
27e5e2745b | ||
|
|
1b26f9a4c6 | ||
|
|
df886259bd | ||
|
|
016ff0feae | ||
|
|
aa6cad5f1d | ||
|
|
e7388779a1 | ||
|
|
6c233e05a9 | ||
|
|
9f013f7644 | ||
|
|
253d8e5a5f | ||
|
|
7f5087c6db | ||
|
|
817071e448 | ||
|
|
f193e9764e | ||
|
|
5f9628e027 | ||
|
|
76d21743fd | ||
|
|
2d3c5b3b7c | ||
|
|
1d85979a74 | ||
|
|
2a85f28963 | ||
|
|
fe4e2f7921 | ||
|
|
9a9ec0c99b | ||
|
|
d5624ba671 | ||
|
|
c805238471 | ||
|
|
e576b989b8 | ||
|
|
f929bfb94c | ||
|
|
f4df80e093 | ||
|
|
390e4cc0bf | ||
|
|
11f9a897e8 | ||
|
|
0e793a660d | ||
|
|
7b2cab5767 | ||
|
|
c51b4290dc | ||
|
|
94a13d7d62 | ||
|
|
edf5fd28c9 | ||
|
|
b834131f50 | ||
|
|
5375d9bb27 | ||
|
|
535fff62f3 | ||
|
|
18b58424ec | ||
|
|
10858ea1dc | ||
|
|
6f8c7a66c8 | ||
|
|
a371390d6c | ||
|
|
a316766ad7 | ||
|
|
a9cc19f530 | ||
|
|
881a151d30 | ||
|
|
785c4caa67 | ||
|
|
4403bc67a1 | ||
|
|
b237113311 |
16
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
16
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@@ -8,13 +8,13 @@ body:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
||||
required: true
|
||||
- label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
required: true
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
|
||||
- label: I confirm that I am using English to submit this report, otherwise it will be closed.
|
||||
required: true
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
@@ -42,20 +42,22 @@ body:
|
||||
attributes:
|
||||
label: Steps to reproduce
|
||||
description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks.
|
||||
placeholder: Having detailed steps helps us reproduce the bug.
|
||||
placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ✔️ Expected Behavior
|
||||
placeholder: What were you expecting?
|
||||
description: Describe what you expected to happen.
|
||||
placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here.
|
||||
validations:
|
||||
required: false
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ❌ Actual Behavior
|
||||
placeholder: What happened instead?
|
||||
description: Describe what actually happened.
|
||||
placeholder: What happened instead? Please do not copy and paste the steps to reproduce here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
8
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,11 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: "\U0001F4A1 Model Providers & Plugins"
|
||||
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
||||
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
||||
- name: "\U0001F4AC Documentation Issues"
|
||||
url: "https://github.com/langgenius/dify-docs/issues/new"
|
||||
about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue.
|
||||
- name: "\U0001F4E7 Discussions"
|
||||
url: https://github.com/langgenius/dify/discussions/categories/general
|
||||
about: General discussions and request help from the community
|
||||
about: General discussions and seek help from the community
|
||||
|
||||
24
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
24
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
@@ -1,24 +0,0 @@
|
||||
name: "📚 Documentation Issue"
|
||||
description: Report issues in our documentation
|
||||
labels:
|
||||
- documentation
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
|
||||
required: true
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Provide a description of requested docs changes
|
||||
placeholder: Briefly describe which document needs to be corrected and why.
|
||||
validations:
|
||||
required: true
|
||||
6
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
6
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
@@ -8,11 +8,11 @@ body:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
||||
required: true
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
|
||||
- label: I confirm that I am using English to submit this report, otherwise it will be closed.
|
||||
required: true
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
|
||||
55
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
55
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
@@ -1,55 +0,0 @@
|
||||
name: "🌐 Localization/Translation issue"
|
||||
description: Report incorrect translations. [please use English :)]
|
||||
labels:
|
||||
- translation
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
|
||||
required: true
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
label: Dify version
|
||||
description: Hover over system tray icon or look at Settings
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
label: Utility with translation issue
|
||||
placeholder: Some area
|
||||
description: Please input here the utility with the translation issue
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
label: 🌐 Language affected
|
||||
placeholder: "German"
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ❌ Actual phrase(s)
|
||||
placeholder: What is there? Please include a screenshot as that is extremely helpful.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ✔️ Expected phrase(s)
|
||||
placeholder: What was expected?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ℹ Why is the current translation wrong
|
||||
placeholder: Why do you feel this is incorrect?
|
||||
validations:
|
||||
required: true
|
||||
@@ -54,7 +54,7 @@
|
||||
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
|
||||
## Quick start
|
||||
|
||||
@@ -65,7 +65,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
</br>
|
||||
|
||||
The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
|
||||
```bash
|
||||
cd dify
|
||||
@@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib
|
||||
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Using Terraform for Deployment
|
||||
|
||||
@@ -261,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media
|
||||
|
||||
## Security disclosure
|
||||
|
||||
To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
|
||||
To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with detailed answer.
|
||||
|
||||
## License
|
||||
|
||||
This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions.
|
||||
This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions.
|
||||
|
||||
@@ -188,6 +188,7 @@ docker compose up -d
|
||||
- [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### استخدام Terraform للتوزيع
|
||||
|
||||
|
||||
@@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
|
||||
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
|
||||
#### টেরাফর্ম ব্যবহার করে ডিপ্লয়
|
||||
|
||||
|
||||
@@ -194,9 +194,9 @@ docker compose up -d
|
||||
|
||||
如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。
|
||||
|
||||
#### 使用 Helm Chart 部署
|
||||
#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署
|
||||
|
||||
使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。
|
||||
使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
@@ -204,6 +204,10 @@ docker compose up -d
|
||||
- [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
|
||||
- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
|
||||
|
||||
#### 使用 Terraform 部署
|
||||
|
||||
使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台
|
||||
|
||||
@@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de
|
||||
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Terraform für die Bereitstellung verwenden
|
||||
|
||||
|
||||
@@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop
|
||||
- [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Uso de Terraform para el despliegue
|
||||
|
||||
|
||||
@@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau
|
||||
- [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NOUVEAU ! Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Utilisation de Terraform pour le déploiement
|
||||
|
||||
|
||||
@@ -202,6 +202,7 @@ docker compose up -d
|
||||
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Terraformを使用したデプロイ
|
||||
|
||||
|
||||
@@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib
|
||||
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Terraform atorlugu pilersitsineq
|
||||
|
||||
|
||||
@@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
||||
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Terraform을 사용한 배포
|
||||
|
||||
|
||||
@@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts]
|
||||
- [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts)
|
||||
- [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Usando o Terraform para Implantação
|
||||
|
||||
|
||||
@@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Uporaba Terraform za uvajanje
|
||||
|
||||
|
||||
@@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify'
|
||||
- [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 YENİ! YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Dağıtım için Terraform Kullanımı
|
||||
|
||||
|
||||
@@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
||||
|
||||
如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。
|
||||
|
||||
如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。
|
||||
如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。
|
||||
|
||||
- [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
- [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
### 使用 Terraform 進行部署
|
||||
|
||||
|
||||
@@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có
|
||||
- [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes)
|
||||
- [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s)
|
||||
- [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
|
||||
|
||||
#### Sử dụng Terraform để Triển khai
|
||||
|
||||
|
||||
@@ -17,6 +17,11 @@ APP_WEB_URL=http://127.0.0.1:3000
|
||||
# Files URL
|
||||
FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
|
||||
# Set this to the internal Docker service URL for proper plugin file access.
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
@@ -444,6 +449,19 @@ MAX_VARIABLE_SIZE=204800
|
||||
# hybrid: Save new data to object storage, read from both object storage and RDBMS
|
||||
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
|
||||
|
||||
# Repository configuration
|
||||
# Core workflow execution repository implementation
|
||||
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
|
||||
|
||||
# Core workflow node execution repository implementation
|
||||
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
# API workflow node execution repository implementation
|
||||
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
# API workflow run repository implementation
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
@@ -237,6 +237,13 @@ class FileAccessConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
INTERNAL_FILES_URL: str = Field(
|
||||
description="Internal base URL for file access within Docker network,"
|
||||
" used for plugin daemon and internal service communication."
|
||||
" Falls back to FILES_URL if not specified.",
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
@@ -530,6 +537,33 @@ class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class RepositoryConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for repository implementations
|
||||
"""
|
||||
|
||||
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
|
||||
description="Repository implementation for WorkflowExecution. Specify as a module path",
|
||||
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
|
||||
)
|
||||
|
||||
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
|
||||
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
|
||||
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
)
|
||||
|
||||
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
|
||||
description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
|
||||
"Specify as a module path",
|
||||
default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
|
||||
)
|
||||
|
||||
API_WORKFLOW_RUN_REPOSITORY: str = Field(
|
||||
description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path",
|
||||
default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
|
||||
)
|
||||
|
||||
|
||||
class AuthConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for authentication and OAuth
|
||||
@@ -896,6 +930,7 @@ class FeatureConfig(
|
||||
MultiModalTransferConfig,
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
|
||||
@@ -162,6 +162,11 @@ class DatabaseConfig(BaseSettings):
|
||||
default=3600,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_USE_LIFO: bool = Field(
|
||||
description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING: bool = Field(
|
||||
description="If True, enables connection pool pre-ping feature to check connections.",
|
||||
default=False,
|
||||
@@ -199,6 +204,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
|
||||
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ from .app import (
|
||||
conversation,
|
||||
conversation_variables,
|
||||
generator,
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
ops_trace,
|
||||
|
||||
@@ -151,6 +151,7 @@ class AppApi(Resource):
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
119
api/controllers/console/app/mcp_server.py
Normal file
119
api/controllers/console/app/mcp_server.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
description = args.get("description")
|
||||
if not description:
|
||||
description = app_model.description or ""
|
||||
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=description,
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
description = args.get("description")
|
||||
if description is None:
|
||||
pass
|
||||
elif not description:
|
||||
server.description = app_model.description or ""
|
||||
else:
|
||||
server.description = description
|
||||
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.filter(AppMCPServer.id == server_id)
|
||||
.filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
@@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models import AppMode, Message
|
||||
|
||||
|
||||
class DailyMessageStatistic(Resource):
|
||||
@@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource):
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT messages.conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
stmt = (
|
||||
sa.select(
|
||||
sa.func.date(
|
||||
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
).label("date"),
|
||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
)
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||
rs = conn.execute(stmt, {"tz": account.timezone})
|
||||
for row in rs:
|
||||
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@@ -68,13 +68,18 @@ def _create_pagination_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
@@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
@@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
|
||||
@@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
|
||||
raise AppNotFoundError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode == AppMode.CHANNEL:
|
||||
raise AppNotFoundError()
|
||||
|
||||
if mode is not None:
|
||||
if isinstance(mode, list):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import send_file
|
||||
from flask import redirect, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -9,17 +10,34 @@ from werkzeug.exceptions import Forbidden
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -34,7 +52,7 @@ class ToolProviderListApi(Resource):
|
||||
req.add_argument(
|
||||
"type",
|
||||
type=str,
|
||||
choices=["builtin", "model", "api", "workflow"],
|
||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="args",
|
||||
@@ -613,6 +631,166 @@ class ToolLabelsApi(Resource):
|
||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.create_mcp_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
user_id=user.id,
|
||||
server_identifier=args["server_identifier"],
|
||||
)
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
MCPToolManageService.update_mcp_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
tenant_id = current_user.current_tenant_id
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if not provider:
|
||||
raise ValueError("provider not found")
|
||||
try:
|
||||
with MCPClient(
|
||||
provider.decrypted_server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
authorization_code=args["authorization_code"],
|
||||
for_list=True,
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials=provider.decrypted_credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
except MCPAuthError:
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
||||
except MCPError as e:
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials={},
|
||||
authed=False,
|
||||
)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
user = current_user
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
class ToolMCPListAllApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||
|
||||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
|
||||
class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
state_key = args["state"]
|
||||
authorization_code = args["code"]
|
||||
handle_callback(state_key, authorization_code)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
@@ -647,8 +825,15 @@ api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provid
|
||||
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
|
||||
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
|
||||
|
||||
# mcp tool provider
|
||||
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
|
||||
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
|
||||
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
|
||||
|
||||
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
|
||||
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
|
||||
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
|
||||
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
|
||||
|
||||
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
|
||||
|
||||
8
api/controllers/mcp/__init__.py
Normal file
8
api/controllers/mcp/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from . import mcp
|
||||
104
api/controllers/mcp/mcp.py
Normal file
104
api/controllers/mcp/mcp.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import api
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPAppApi(Resource):
|
||||
def post(self, server_code):
|
||||
def int_or_str(value):
|
||||
if isinstance(value, (int, str)):
|
||||
return value
|
||||
else:
|
||||
return None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("jsonrpc", type=str, required=True, location="json")
|
||||
parser.add_argument("method", type=str, required=True, location="json")
|
||||
parser.add_argument("params", type=dict, required=False, location="json")
|
||||
parser.add_argument("id", type=int_or_str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
request_id = args.get("id")
|
||||
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from dateutil.parser import isoparse
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api import api
|
||||
@@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
@@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource):
|
||||
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
# Use repository to get workflow run
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
workflow_run = workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
return workflow_run
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ import logging
|
||||
import uuid
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
@@ -161,10 +163,14 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
message_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
@@ -254,10 +260,14 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
prompt_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
@@ -409,12 +419,15 @@ class BaseAgentRunner(AppRunner):
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
result.append(prompt_message)
|
||||
|
||||
messages: list[Message] = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
messages = (
|
||||
(
|
||||
db.session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.message.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class AgentStrategyEntity(BaseModel):
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: Optional[dict] = None
|
||||
features: Optional[list[AgentFeature]] = None
|
||||
|
||||
meta_version: Optional[str] = None
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@@ -15,10 +15,12 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
|
||||
tenant_id: str
|
||||
declaration: AgentStrategyEntity
|
||||
meta_version: str | None = None
|
||||
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity, meta_version: str | None):
|
||||
self.tenant_id = tenant_id
|
||||
self.declaration = declaration
|
||||
self.meta_version = meta_version
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
@@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
@@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
@@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
@@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
|
||||
@@ -16,9 +16,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationError
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
user_id = None
|
||||
user_id: str | None = None
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
@@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
session.commit()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
|
||||
}
|
||||
system_inputs = SystemVariable(
|
||||
query=query,
|
||||
files=files,
|
||||
conversation_id=self.conversation.id,
|
||||
user_id=user_id,
|
||||
dialogue_count=self._dialogue_count,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables),
|
||||
)
|
||||
|
||||
# init graph
|
||||
|
||||
@@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
@@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
workflow_system_variables=SystemVariable(
|
||||
query=message.query,
|
||||
files=application_generate_entity.files,
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_session_id,
|
||||
dialogue_count=dialogue_count,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_execution_id=application_generate_entity.workflow_run_id,
|
||||
),
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
|
||||
@@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
@@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
@@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
||||
}
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=files,
|
||||
user_id=user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@@ -55,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@@ -68,7 +67,6 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -109,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
|
||||
},
|
||||
workflow_system_variables=SystemVariable(
|
||||
files=application_generate_entity.files,
|
||||
user_id=user_session_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_execution_id=application_generate_entity.workflow_execution_id,
|
||||
),
|
||||
workflow_info=CycleManagerWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
@@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
|
||||
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
|
||||
assert workflow_run is not None
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
@@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
return
|
||||
|
||||
workflow_app_log = WorkflowAppLog()
|
||||
workflow_app_log.tenant_id = workflow_run.tenant_id
|
||||
workflow_app_log.app_id = workflow_run.app_id
|
||||
workflow_app_log.workflow_id = workflow_run.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_run.id
|
||||
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
|
||||
workflow_app_log.workflow_id = workflow_execution.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_execution.id_
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = self._created_by_role
|
||||
workflow_app_log.created_by = self._user_id
|
||||
|
||||
@@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
@@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ class CommonParameterType(StrEnum):
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
# MCP object and array type parameters
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class AppSelectorScope(StrEnum):
|
||||
|
||||
@@ -21,7 +21,9 @@ def get_signed_file_url(upload_file_id: str) -> str:
|
||||
|
||||
|
||||
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/upload/for-plugin"
|
||||
# Plugin access should use internal URL for Docker network communication
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
url = f"{base_url}/files/upload/for-plugin"
|
||||
|
||||
if user_id is None:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
@@ -5,6 +5,8 @@ from base64 import b64encode
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.variables.utils import SegmentJSONEncoder
|
||||
|
||||
|
||||
class TemplateTransformer(ABC):
|
||||
_code_placeholder: str = "{{code}}"
|
||||
@@ -43,17 +45,13 @@ class TemplateTransformer(ABC):
|
||||
result_str = cls.extract_result_str_from_response(response)
|
||||
result = json.loads(result_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
|
||||
raise ValueError(f"Failed to parse JSON response: {str(e)}.")
|
||||
except ValueError as e:
|
||||
# Re-raise ValueError from extract_result_str_from_response
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
|
||||
|
||||
# Check if the result contains an error
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
raise ValueError(f"JavaScript execution error: {result['error']}")
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
|
||||
if not all(isinstance(k, str) for k in result):
|
||||
@@ -95,7 +93,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
|
||||
inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
0
api/core/mcp/__init__.py
Normal file
0
api/core/mcp/__init__.py
Normal file
342
api/core/mcp/auth/auth_flow.py
Normal file
342
api/core/mcp/auth/auth_flow.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
|
||||
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
state_key = secrets.token_urlsafe(32)
|
||||
|
||||
# Store the state data in Redis with expiration
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
|
||||
|
||||
return state_key
|
||||
|
||||
|
||||
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
|
||||
# Get state data from Redis
|
||||
state_data = redis_client.get(redis_key)
|
||||
|
||||
if not state_data:
|
||||
raise ValueError("State parameter has expired or does not exist")
|
||||
|
||||
# Delete the state data from Redis immediately after retrieval to prevent reuse
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
try:
|
||||
# Parse and validate the state data
|
||||
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
||||
|
||||
return oauth_state
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
tokens = exchange_authorization(
|
||||
full_state_data.server_url,
|
||||
full_state_data.metadata,
|
||||
full_state_data.client_information,
|
||||
authorization_code,
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except requests.RequestException as e:
|
||||
if isinstance(e, requests.ConnectionError):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
|
||||
|
||||
def start_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
if metadata:
|
||||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
not metadata.code_challenge_methods_supported
|
||||
or code_challenge_method not in metadata.code_challenge_methods_supported
|
||||
):
|
||||
raise ValueError(
|
||||
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_challenge()
|
||||
|
||||
# Prepare state data with all necessary information
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
server_url=server_url,
|
||||
metadata=metadata,
|
||||
client_information=client_information,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
|
||||
# Store state data in Redis and generate secure state key
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
|
||||
params = {
|
||||
"response_type": response_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"redirect_uri": redirect_url,
|
||||
"state": state_key,
|
||||
}
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
authorization_code: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code": authorization_code,
|
||||
"code_verifier": code_verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def register_client(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_metadata: OAuthClientMetadata,
|
||||
) -> OAuthClientInformationFull:
|
||||
"""Performs OAuth 2.0 Dynamic Client Registration."""
|
||||
if metadata:
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||
registration_url = metadata.registration_endpoint
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = requests.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if not response.ok:
|
||||
response.raise_for_status()
|
||||
return OAuthClientInformationFull.model_validate(response.json())
|
||||
|
||||
|
||||
def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
|
||||
# Handle client registration if needed
|
||||
client_information = provider.client_information()
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
|
||||
try:
|
||||
# Retrieve state data from Redis using state key
|
||||
full_state_data = _retrieve_redis_state(state_param)
|
||||
|
||||
code_verifier = full_state_data.code_verifier
|
||||
redirect_uri = full_state_data.redirect_uri
|
||||
|
||||
if not code_verifier or not redirect_uri:
|
||||
raise ValueError("Missing code_verifier or redirect_uri in state data")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid state parameter: {e}")
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
provider.save_tokens(tokens)
|
||||
return {"result": "success"}
|
||||
|
||||
provider_tokens = provider.tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
||||
provider.save_tokens(new_tokens)
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
return {"authorization_url": authorization_url}
|
||||
81
api/core/mcp/auth/auth_provider.py
Normal file
81
api/core/mcp/auth/auth_provider.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""Loads information about this OAuth client."""
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> Optional[OAuthTokens]:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def save_tokens(self, tokens: OAuthTokens) -> None:
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str) -> None:
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
361
api/core/mcp/client/sse_client.py
Normal file
361
api/core/mcp/client/sse_client.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self._endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self._exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
"""Remove request parameters from URL, keeping only the path."""
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> None:
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
||||
Args:
|
||||
endpoint_url: The endpoint URL to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
"""
|
||||
url_parsed = urlparse(self.url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||
|
||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
|
||||
"""Handle an 'endpoint' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
endpoint_url = urljoin(self.url, sse_data)
|
||||
logger.info(f"Received endpoint URL: {endpoint_url}")
|
||||
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(_StatusError(ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(_StatusReady(endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
||||
"""Handle a 'message' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
read_queue: Queue to put parsed messages.
|
||||
"""
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse_data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
sse: The SSE event object.
|
||||
read_queue: Queue for message events.
|
||||
status_queue: Queue for status events.
|
||||
"""
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
self._handle_endpoint_event(sse.data, status_queue)
|
||||
case "message":
|
||||
self._handle_message_event(sse.data, read_queue)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
|
||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
event_source: The SSE event source.
|
||||
read_queue: Queue to put received messages.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, read_queue, status_queue)
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
|
||||
"""Send a single message to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send to.
|
||||
message: The message to send.
|
||||
"""
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
|
||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
|
||||
"""Handle writing messages to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send messages to.
|
||||
write_queue: Queue to read messages from.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
|
||||
self._send_message(client, endpoint_url, message)
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"Post writer shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
|
||||
"""Wait for the endpoint URL from the status queue.
|
||||
|
||||
Args:
|
||||
status_queue: Queue to read status from.
|
||||
|
||||
Returns:
|
||||
The endpoint URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status._endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status._exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
Args:
|
||||
executor: Thread pool executor.
|
||||
client: HTTP client.
|
||||
event_source: SSE event source.
|
||||
|
||||
Returns:
|
||||
Tuple of (read_queue, write_queue).
|
||||
"""
|
||||
read_queue: ReadQueue = queue.Queue()
|
||||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
# Wait for endpoint URL
|
||||
endpoint_url = self._wait_for_endpoint(status_queue)
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
# Start post writer thread
|
||||
executor.submit(self.post_writer, client, endpoint_url, write_queue)
|
||||
|
||||
return read_queue, write_queue
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_queue, write_queue) for message communication.
|
||||
"""
|
||||
transport = SSETransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
|
||||
"""
|
||||
Send a message to the server using the provided HTTP client.
|
||||
|
||||
Args:
|
||||
http_client: The HTTP client to use for sending
|
||||
endpoint_url: The endpoint URL to send the message to
|
||||
session_message: The message to send
|
||||
"""
|
||||
try:
|
||||
response = http_client.post(
|
||||
endpoint_url,
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error sending message")
|
||||
raise
|
||||
|
||||
|
||||
def read_messages(
|
||||
sse_client: SSEClient,
|
||||
) -> Generator[SessionMessage | Exception, None, None]:
|
||||
"""
|
||||
Read messages from the SSE client.
|
||||
|
||||
Args:
|
||||
sse_client: The SSE client to read from
|
||||
|
||||
Yields:
|
||||
SessionMessage or Exception for each event received
|
||||
"""
|
||||
try:
|
||||
for sse in sse_client.events():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
yield SessionMessage(message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
yield exc
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error reading SSE messages")
|
||||
yield exc
|
||||
476
api/core/mcp/client/streamable_client.py
Normal file
476
api/core/mcp/client/streamable_client.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionMessageOrError = SessionMessage | Exception | None
|
||||
# Queue types with clearer names for their roles
|
||||
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
|
||||
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "Accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.Client
|
||||
headers: dict[str, str]
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: timedelta
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
headers = base_headers.copy()
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], None] | None = None,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"SSE message: {message}")
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
# Put message in queue that goes to client
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc:
|
||||
# Put exception in queue that goes to client
|
||||
server_to_client_queue.put(exc)
|
||||
return False
|
||||
elif sse.event == "ping":
|
||||
logger.debug("Received ping event")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
|
||||
def handle_get_stream(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle GET stream for server-initiated messages."""
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(f"GET stream error (non-fatal): {exc}")
|
||||
|
||||
def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token")
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = response.read()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
server_to_client_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
server_to_client_queue.put(ValueError(error_msg))
|
||||
|
||||
def _send_session_terminated_error(
|
||||
self,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated by server"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
def post_writer(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
client_to_server_queue: ClientToServerQueue,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
start_get_stream: Callable[[], None],
|
||||
) -> None:
|
||||
"""Handle writing requests to the server.
|
||||
|
||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||
Responses are written to the server_to_client_queue.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
break
|
||||
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
headers=self.request_headers,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
self._handle_resumption_request(ctx)
|
||||
else:
|
||||
self._handle_post_request(ctx)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client) -> None:
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
response = client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
ServerToClientQueue, # Queue for receiving messages FROM server
|
||||
ClientToServerQueue, # Queue for sending messages TO server
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- server_to_client_queue: Queue for reading messages FROM the server
|
||||
- client_to_server_queue: Queue for sending messages TO the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
# Create queues with clear directional meaning
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream() -> None:
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
19
api/core/mcp/entities.py
Normal file
19
api/core/mcp/entities.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
10
api/core/mcp/error.py
Normal file
10
api/core/mcp/error.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class MCPError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
150
api/core/mcp/mcp_client.py
Normal file
150
api/core/mcp/mcp_client.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.client_type = "streamable"
|
||||
self.server_url = server_url
|
||||
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
self.authorization_code = authorization_code
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
||||
self._session_context: Optional[ClientSession] = None
|
||||
self.exit_stack = ExitStack()
|
||||
|
||||
# Whether the client has been initialized
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
|
||||
):
|
||||
self.cleanup()
|
||||
|
||||
def _initialize(
|
||||
self,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
|
||||
"mcp": streamablehttp_client,
|
||||
"sse": sse_client,
|
||||
}
|
||||
|
||||
parsed_url = urlparse(self.server_url)
|
||||
path = parsed_url.path
|
||||
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||
try:
|
||||
client_factory = connection_methods[method_name]
|
||||
self.connect_server(client_factory, method_name)
|
||||
except KeyError:
|
||||
try:
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(
|
||||
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
||||
):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else {}
|
||||
)
|
||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||
if self._streams_context is None:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self.exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self.exit_stack.enter_context(self._session_context)
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
try:
|
||||
auth(self.provider, self.server_url, self.authorization_code)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to authenticate: {e}")
|
||||
self.token = self.provider.tokens()
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
|
||||
except MCPConnectionError:
|
||||
raise
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
if not self._initialized or not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
tools = response.tools
|
||||
return tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict):
|
||||
"""Call a tool"""
|
||||
if not self._initialized or not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self.exit_stack.close()
|
||||
self._session = None
|
||||
self._session_context = None
|
||||
self._streams_context = None
|
||||
self._initialized = False
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
228
api/core/mcp/server/streamable_http.py
Normal file
228
api/core/mcp/server/streamable_http.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result | str):
|
||||
if isinstance(response, str):
|
||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
||||
yield sse_content
|
||||
return
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
||||
return create_mcp_error_response(request_id, code, message, data)
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
types.InitializedNotification: self.handle_notification,
|
||||
types.PingRequest: self.handle_ping,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
|
||||
def handle_notification(self):
|
||||
return "ping"
|
||||
|
||||
def handle_ping(self):
|
||||
return types.EmptyResult()
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.app.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments
|
||||
if not args:
|
||||
raise ValueError("No arguments provided")
|
||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
||||
args = {"query": "", "inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(
|
||||
self.app,
|
||||
self.end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
answer = ""
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
for item in response.generator:
|
||||
data = item
|
||||
if isinstance(data, str) and data.startswith("data: "):
|
||||
try:
|
||||
json_str = data[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(response, Mapping):
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
parameters[item.variable] = {}
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
try:
|
||||
description = self.mcp_server.parameters_dict[item.variable]
|
||||
except KeyError:
|
||||
description = ""
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
415
api/core/mcp/session/base_session.py
Normal file
415
api/core/mcp/session/base_session.py
Normal file
@@ -0,0 +1,415 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||
from contextlib import ExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MessageMetadata,
|
||||
RequestId,
|
||||
RequestParams,
|
||||
ServerMessageMetadata,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
cancellation handling:
|
||||
|
||||
Example:
|
||||
with request_responder as resp:
|
||||
resp.respond(result)
|
||||
|
||||
The context manager ensures:
|
||||
1. Proper cancellation scope setup and cleanup
|
||||
2. Request completion tracking
|
||||
3. Cleanup of in-flight requests
|
||||
"""
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: Any
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._completed = False
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self._completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
|
||||
def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
"""Send a response for this request.
|
||||
|
||||
Must be called within a context manager block.
|
||||
Raises:
|
||||
RuntimeError: If not used within a context manager
|
||||
AssertionError: If request was already responded to
|
||||
"""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
self._completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this request and mark it as completed."""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is a context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_receive_request_type: type[ReceiveRequestT]
|
||||
_receive_notification_type: type[ReceiveNotificationT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: queue.Queue,
|
||||
write_stream: queue.Queue,
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._exit_stack = ExitStack()
|
||||
# Initialize executor and future to None for proper cleanup checks
|
||||
self._executor: ThreadPoolExecutor | None = None
|
||||
self._receiver_future: Future | None = None
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
# The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
|
||||
# ensures no unnecessary threads are created.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||
return self
|
||||
|
||||
def check_receiver_status(self) -> None:
|
||||
"""`check_receiver_status` ensures that any exceptions raised during the
|
||||
execution of `_receive_loop` are retrieved and propagated."""
|
||||
if self._receiver_future and self._receiver_future.done():
|
||||
self._receiver_future.result()
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
self._read_stream.put(None)
|
||||
self._write_stream.put(None)
|
||||
|
||||
# Wait for the receiver loop to finish
|
||||
if self._receiver_future:
|
||||
try:
|
||||
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
||||
except TimeoutError:
|
||||
# If the receiver loop is still running after timeout, we'll force shutdown
|
||||
pass
|
||||
|
||||
# Shutdown the executor
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error. If a request read timeout is provided, it
|
||||
will take precedence over the session read timeout.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
|
||||
self._response_streams[request_id] = response_queue
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
|
||||
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = float(request_read_timeout_seconds.total_seconds())
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = float(self._session_read_timeout_seconds.total_seconds())
|
||||
while True:
|
||||
try:
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
except queue.Empty:
|
||||
self.check_receiver_status()
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(
|
||||
code=500,
|
||||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
|
||||
def send_notification(
|
||||
self,
|
||||
notification: SendNotificationT,
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
|
||||
)
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
self._write_stream.put(session_message)
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _receive_loop(self) -> None:
|
||||
"""
|
||||
Main message processing loop.
|
||||
In a real synchronous implementation, this would likely run in a separate thread.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder._completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification)
|
||||
self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
|
||||
else: # Response or error
|
||||
response_queue = self._response_streams.get(message.message.root.id)
|
||||
if response_queue is not None:
|
||||
response_queue.put(message.message.root)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
365
api/core/mcp/session/client_session.py
Normal file
365
api/core/mcp/session/client_session.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise ValueError(str(message))
|
||||
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
||||
pass
|
||||
|
||||
|
||||
def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream,
|
||||
write_stream,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._client_info = client_info or DEFAULT_CLIENT_INFO
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
roots = types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
|
||||
result = self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=self._client_info,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
|
||||
|
||||
self.send_notification(
|
||||
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
list_roots_response = self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(list_roots_response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
def _received_notification(self, notification: types.ServerNotification) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
self._logging_callback(params)
|
||||
case _:
|
||||
pass
|
||||
1217
api/core/mcp/types.py
Normal file
1217
api/core/mcp/types.py
Normal file
File diff suppressed because it is too large
Load Diff
114
api/core/mcp/utils.py
Normal file
114
api/core/mcp/utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import ErrorData, JSONRPCError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
mounts=proxy_mounts,
|
||||
)
|
||||
else:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, **kwargs):
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
|
||||
"""Create MCP error response"""
|
||||
error_data = ErrorData(code=code, message=message, data=data)
|
||||
json_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id or 1,
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
yield sse_content
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowRun
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
self.conversation = conversation
|
||||
self.model_instance = model_instance
|
||||
|
||||
@@ -36,20 +42,8 @@ class TokenBufferMemory:
|
||||
app_record = self.conversation.app
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
query = (
|
||||
db.session.query(
|
||||
Message.id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message.created_at,
|
||||
Message.workflow_run_id,
|
||||
Message.parent_message_id,
|
||||
Message.answer_tokens,
|
||||
)
|
||||
.filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
stmt = (
|
||||
select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
if message_limit and message_limit > 0:
|
||||
@@ -57,7 +51,9 @@ class TokenBufferMemory:
|
||||
else:
|
||||
message_limit = 500
|
||||
|
||||
messages = query.limit(message_limit).all()
|
||||
stmt = stmt.limit(message_limit)
|
||||
|
||||
messages = db.session.scalars(stmt).all()
|
||||
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
@@ -74,18 +70,20 @@ class TokenBufferMemory:
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(
|
||||
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
if message.workflow_run_id:
|
||||
workflow_run = (
|
||||
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
||||
)
|
||||
|
||||
if workflow_run and workflow_run.workflow:
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
workflow_run.workflow.features_dict, is_vision=False
|
||||
)
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
if file_extra_config and app_record:
|
||||
|
||||
@@ -123,6 +123,8 @@ class ProviderEntity(BaseModel):
|
||||
description: Optional[I18nObject] = None
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
icon_small_dark: Optional[I18nObject] = None
|
||||
icon_large_dark: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
|
||||
@@ -284,7 +284,8 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
return node_span
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
|
||||
@@ -306,7 +307,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
@@ -381,7 +382,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||
@@ -415,7 +416,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
|
||||
@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
@@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
@@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
@@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
@@ -241,7 +241,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
"trace_id": opik_trace_id,
|
||||
"id": prepare_opik_uuid(created_at, node_execution_id),
|
||||
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
|
||||
"name": node_type,
|
||||
"name": node_name,
|
||||
"type": run_type,
|
||||
"start_time": created_at,
|
||||
"end_time": finished_at,
|
||||
|
||||
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
@@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,6 +43,19 @@ class PluginParameterType(enum.StrEnum):
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = CommonParameterType.ARRAY.value
|
||||
OBJECT = CommonParameterType.OBJECT.value
|
||||
|
||||
|
||||
class MCPServerParameterType(enum.StrEnum):
|
||||
"""
|
||||
MCP server got complex parameter types
|
||||
"""
|
||||
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class PluginParameterAutoGenerate(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
@@ -138,6 +151,34 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
if value and not isinstance(value, list):
|
||||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case PluginParameterType.ARRAY:
|
||||
if not isinstance(value, list):
|
||||
# Try to parse JSON string for arrays
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, list):
|
||||
return parsed_value
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return [value]
|
||||
return value
|
||||
case PluginParameterType.OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
# Try to parse JSON string for objects
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, dict):
|
||||
return parsed_value
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
return {}
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
except ValueError:
|
||||
|
||||
@@ -72,12 +72,14 @@ class PluginDeclaration(BaseModel):
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
version: Optional[str] = Field(default=None)
|
||||
|
||||
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
|
||||
description: I18nObject
|
||||
icon: str
|
||||
icon_dark: Optional[str] = Field(default=None)
|
||||
label: I18nObject
|
||||
category: PluginCategory
|
||||
created_at: datetime.datetime
|
||||
|
||||
@@ -53,6 +53,7 @@ class PluginAgentProviderEntity(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: AgentProviderEntityWithPlugin
|
||||
meta: PluginDeclaration.Meta
|
||||
|
||||
|
||||
class PluginBasicBooleanResponse(BaseModel):
|
||||
|
||||
@@ -32,7 +32,7 @@ class RequestInvokeTool(BaseModel):
|
||||
Request to invoke a tool
|
||||
"""
|
||||
|
||||
tool_type: Literal["builtin", "workflow", "api"]
|
||||
tool_type: Literal["builtin", "workflow", "api", "mcp"]
|
||||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
|
||||
@@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient):
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/list",
|
||||
PluginListResponse,
|
||||
params={"page": 1, "page_size": 256},
|
||||
params={"page": 1, "page_size": 256, "response_type": "paged"},
|
||||
)
|
||||
return result.list
|
||||
|
||||
@@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient):
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/list",
|
||||
PluginListResponse,
|
||||
params={"page": page, "page_size": page_size},
|
||||
params={"page": page, "page_size": page_size, "response_type": "paged"},
|
||||
)
|
||||
|
||||
def upload_pkg(
|
||||
|
||||
@@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
||||
if self.with_variable_tmpl:
|
||||
vp = VariablePool()
|
||||
vp = VariablePool.empty()
|
||||
for k, v in inputs.items():
|
||||
if k.startswith("#"):
|
||||
vp.add(k[1:-1].split("."), v)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
|
||||
from constants import UUID_NIL
|
||||
from models import Message
|
||||
|
||||
|
||||
def extract_thread_messages(messages: list[Any]):
|
||||
thread_messages = []
|
||||
def extract_thread_messages(messages: Sequence[Message]):
|
||||
thread_messages: list[Message] = []
|
||||
next_message = None
|
||||
|
||||
for message in messages:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
@@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int:
|
||||
Get the number of thread messages based on the parent message id.
|
||||
"""
|
||||
# Fetch all messages related to the conversation
|
||||
query = (
|
||||
db.session.query(
|
||||
Message.id,
|
||||
Message.parent_message_id,
|
||||
Message.answer,
|
||||
)
|
||||
.filter(
|
||||
Message.conversation_id == conversation_id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())
|
||||
|
||||
messages = query.all()
|
||||
messages = db.session.scalars(stmt).all()
|
||||
|
||||
# Extract thread messages
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import load_only
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
@@ -144,7 +144,8 @@ class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
|
||||
return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
with Session(db.engine) as session:
|
||||
return session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
|
||||
@@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
replication_factor: int = 1
|
||||
write_consistency_factor: int = 1
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
@@ -127,6 +128,7 @@ class QdrantVector(BaseVector):
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
replication_factor=self._client_config.replication_factor,
|
||||
write_consistency_factor=self._client_config.write_consistency_factor,
|
||||
)
|
||||
|
||||
# create group_id payload index
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Optional
|
||||
|
||||
import tablestore # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -50,6 +51,29 @@ class TableStoreVector(BaseVector):
|
||||
self._index_name = f"{collection_name}_idx"
|
||||
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
||||
|
||||
def create_collection(self, embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
docs = []
|
||||
request = BatchGetRowRequest()
|
||||
columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
|
||||
rows_to_get = [[("id", _id)] for _id in ids]
|
||||
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
|
||||
|
||||
result = self._tablestore_client.batch_get_row(request)
|
||||
table_result = result.get_result_by_table(self._table_name)
|
||||
for item in table_result:
|
||||
if item.is_ok and item.row:
|
||||
kv = {k: v for k, v, t in item.row.attribute_columns}
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TABLESTORE
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -13,6 +15,8 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, Whitelist
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractVectorFactory(ABC):
|
||||
@abstractmethod
|
||||
@@ -173,8 +177,20 @@ class Vector:
|
||||
|
||||
def create(self, texts: Optional[list] = None, **kwargs):
|
||||
if texts:
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
||||
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
|
||||
start = time.time()
|
||||
logger.info(f"start embedding {len(texts)} texts {start}")
|
||||
batch_size = 1000
|
||||
total_batches = len(texts) + batch_size - 1
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
batch_start = time.time()
|
||||
logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
|
||||
batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
|
||||
logger.info(
|
||||
f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
|
||||
)
|
||||
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||
logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
if kwargs.get("duplicate_check", False):
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import Float, and_, or_, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -598,7 +599,8 @@ class DatasetRetrieval:
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
@@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.workflow.repository package.
|
||||
"""
|
||||
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"DifyCoreRepositoryFactory",
|
||||
"RepositoryImportError",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
||||
224
api/core/repositories/factory.py
Normal file
224
api/core/repositories/factory.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Repository factory for dynamically creating repository instances based on configuration.
|
||||
|
||||
This module provides a Django-like settings system for repository implementations,
|
||||
allowing users to configure different repository backends through string paths.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Protocol, Union
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RepositoryImportError(Exception):
|
||||
"""Raised when a repository implementation cannot be imported or instantiated."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DifyCoreRepositoryFactory:
|
||||
"""
|
||||
Factory for creating repository instances based on configuration.
|
||||
|
||||
This factory supports Django-like settings where repository implementations
|
||||
are specified as module paths (e.g., 'module.submodule.ClassName').
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _import_class(class_path: str) -> type:
|
||||
"""
|
||||
Import a class from a module path string.
|
||||
|
||||
Args:
|
||||
class_path: Full module path to the class (e.g., 'module.submodule.ClassName')
|
||||
|
||||
Returns:
|
||||
The imported class
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the class cannot be imported
|
||||
"""
|
||||
try:
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
repo_class = getattr(module, class_name)
|
||||
assert isinstance(repo_class, type)
|
||||
return repo_class
|
||||
except (ValueError, ImportError, AttributeError) as e:
|
||||
raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore
|
||||
"""
|
||||
Validate that a class implements the expected repository interface.
|
||||
|
||||
Args:
|
||||
repository_class: The class to validate
|
||||
expected_interface: The expected interface/protocol
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the class doesn't implement the interface
|
||||
"""
|
||||
# Check if the class has all required methods from the protocol
|
||||
required_methods = [
|
||||
method
|
||||
for method in dir(expected_interface)
|
||||
if not method.startswith("_") and callable(getattr(expected_interface, method, None))
|
||||
]
|
||||
|
||||
missing_methods = []
|
||||
for method_name in required_methods:
|
||||
if not hasattr(repository_class, method_name):
|
||||
missing_methods.append(method_name)
|
||||
|
||||
if missing_methods:
|
||||
raise RepositoryImportError(
|
||||
f"Repository class '{repository_class.__name__}' does not implement required methods "
|
||||
f"{missing_methods} from interface '{expected_interface.__name__}'"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
|
||||
"""
|
||||
Validate that a repository class constructor accepts required parameters.
|
||||
|
||||
Args:
|
||||
repository_class: The class to validate
|
||||
required_params: List of required parameter names
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the constructor doesn't accept required parameters
|
||||
"""
|
||||
|
||||
try:
|
||||
# MyPy may flag the line below with the following error:
|
||||
#
|
||||
# > Accessing "__init__" on an instance is unsound, since
|
||||
# > instance.__init__ could be from an incompatible subclass.
|
||||
#
|
||||
# Despite this, we need to ensure that the constructor of `repository_class`
|
||||
# has a compatible signature.
|
||||
signature = inspect.signature(repository_class.__init__) # type: ignore[misc]
|
||||
param_names = list(signature.parameters.keys())
|
||||
|
||||
# Remove 'self' parameter
|
||||
if "self" in param_names:
|
||||
param_names.remove("self")
|
||||
|
||||
missing_params = [param for param in required_params if param not in param_names]
|
||||
if missing_params:
|
||||
raise RepositoryImportError(
|
||||
f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
|
||||
f"{missing_params}. Expected parameters: {required_params}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise RepositoryImportError(
|
||||
f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def create_workflow_execution_repository(
|
||||
cls,
|
||||
session_factory: Union[sessionmaker, Engine],
|
||||
user: Union[Account, EndUser],
|
||||
app_id: str,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
) -> WorkflowExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowExecutionRepository instance based on configuration.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine
|
||||
user: Account or EndUser object
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
|
||||
Returns:
|
||||
Configured WorkflowExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be created
|
||||
"""
|
||||
class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
|
||||
logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
|
||||
cls._validate_constructor_signature(
|
||||
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create WorkflowExecutionRepository")
|
||||
raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e
|
||||
|
||||
@classmethod
|
||||
def create_workflow_node_execution_repository(
|
||||
cls,
|
||||
session_factory: Union[sessionmaker, Engine],
|
||||
user: Union[Account, EndUser],
|
||||
app_id: str,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom,
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionRepository instance based on configuration.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine
|
||||
user: Account or EndUser object
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
|
||||
Returns:
|
||||
Configured WorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be created
|
||||
"""
|
||||
class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
|
||||
logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
|
||||
cls._validate_constructor_signature(
|
||||
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create WorkflowNodeExecutionRepository")
|
||||
raise RepositoryImportError(
|
||||
f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
|
||||
) from e
|
||||
@@ -39,19 +39,22 @@ class ApiToolProviderController(ToolProviderController):
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
|
||||
ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")),
|
||||
ProviderConfig.Option(
|
||||
value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数")
|
||||
),
|
||||
],
|
||||
default="none",
|
||||
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
|
||||
)
|
||||
]
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
if auth_type == ApiProviderAuthType.API_KEY_HEADER:
|
||||
credentials_schema = [
|
||||
*credentials_schema,
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
required=False,
|
||||
default="api_key",
|
||||
default="Authorization",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
|
||||
),
|
||||
@@ -74,6 +77,25 @@ class ApiToolProviderController(ToolProviderController):
|
||||
],
|
||||
),
|
||||
]
|
||||
elif auth_type == ApiProviderAuthType.API_KEY_QUERY:
|
||||
credentials_schema = [
|
||||
*credentials_schema,
|
||||
ProviderConfig(
|
||||
name="api_key_query_param",
|
||||
required=False,
|
||||
default="key",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(
|
||||
en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称"
|
||||
),
|
||||
),
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
|
||||
),
|
||||
]
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
|
||||
|
||||
@@ -78,8 +78,8 @@ class ApiTool(Tool):
|
||||
if "auth_type" not in credentials:
|
||||
raise ToolProviderCredentialValidationError("Missing auth_type")
|
||||
|
||||
if credentials["auth_type"] == "api_key":
|
||||
api_key_header = "api_key"
|
||||
if credentials["auth_type"] in ("api_key_header", "api_key"): # backward compatibility:
|
||||
api_key_header = "Authorization"
|
||||
|
||||
if "api_key_header" in credentials:
|
||||
api_key_header = credentials["api_key_header"]
|
||||
@@ -100,6 +100,11 @@ class ApiTool(Tool):
|
||||
|
||||
headers[api_key_header] = credentials["api_key_value"]
|
||||
|
||||
elif credentials["auth_type"] == "api_key_query":
|
||||
# For query parameter authentication, we don't add anything to headers
|
||||
# The query parameter will be added in do_http_request method
|
||||
pass
|
||||
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
@@ -154,6 +159,15 @@ class ApiTool(Tool):
|
||||
cookies = {}
|
||||
files = []
|
||||
|
||||
# Add API key to query parameters if auth_type is api_key_query
|
||||
if self.runtime and self.runtime.credentials:
|
||||
credentials = self.runtime.credentials
|
||||
if credentials.get("auth_type") == "api_key_query":
|
||||
api_key_query_param = credentials.get("api_key_query_param", "key")
|
||||
api_key_value = credentials.get("api_key_value")
|
||||
if api_key_value:
|
||||
params[api_key_query_param] = api_key_value
|
||||
|
||||
# check parameters
|
||||
for parameter in self.api_bundle.openapi.get("parameters", []):
|
||||
value = self.get_parameter_value(parameter, parameters)
|
||||
@@ -213,7 +227,8 @@ class ApiTool(Tool):
|
||||
elif "default" in property:
|
||||
body[name] = property["default"]
|
||||
else:
|
||||
body[name] = None
|
||||
# omit optional parameters that weren't provided, instead of setting them to None
|
||||
pass
|
||||
break
|
||||
|
||||
# replace path parameters
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -18,7 +19,7 @@ class ToolApiEntity(BaseModel):
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
|
||||
|
||||
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
@@ -27,6 +28,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool")
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Optional[dict] = None
|
||||
@@ -37,6 +39,10 @@ class ToolProviderApiEntity(BaseModel):
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
# MCP
|
||||
server_url: Optional[str] = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@@ -52,8 +58,13 @@ class ToolProviderApiEntity(BaseModel):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
if parameter.get("input_schema") is None:
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP.value:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
@@ -62,6 +73,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"icon_dark": self.icon_dark,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
@@ -69,4 +81,9 @@ class ToolProviderApiEntity(BaseModel):
|
||||
"allow_delete": self.allow_delete,
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
**optional_fields,
|
||||
}
|
||||
|
||||
def optional_field(self, key: str, value: Any) -> dict:
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_seriali
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
MCPServerParameterType,
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
@@ -49,6 +50,7 @@ class ToolProviderType(enum.StrEnum):
|
||||
API = "api"
|
||||
APP = "app"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
MCP = "mcp"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
@@ -94,7 +96,8 @@ class ApiProviderAuthType(Enum):
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
API_KEY_HEADER = "api_key_header"
|
||||
API_KEY_QUERY = "api_key_query"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
@@ -242,6 +245,10 @@ class ToolParameter(PluginParameter):
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = MCPServerParameterType.ARRAY.value
|
||||
OBJECT = MCPServerParameterType.OBJECT.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
@@ -260,6 +267,8 @@ class ToolParameter(PluginParameter):
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
# MCP object and array type parameters use this field to store the schema
|
||||
input_schema: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
@@ -309,6 +318,7 @@ class ToolProviderIdentity(BaseModel):
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
|
||||
130
api/core/tools/mcp_tool/provider.py
Normal file
130
api/core/tools/mcp_tool/provider.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntityWithPlugin,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
|
||||
super().__init__(entity)
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@classmethod
|
||||
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
|
||||
user = db_provider.load_user()
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.server_identifier,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=None,
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
92
api/core/tools/mcp_tool/tool.py
Normal file
92
api/core/tools/mcp_tool/tool.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
server_url: str
|
||||
provider_id: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.runtime_parameters = None
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ToolInvokeError("Please auth the tool first") from e
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
try:
|
||||
content_json = json.loads(content.text)
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
for item in content_json:
|
||||
yield self.create_json_message(item)
|
||||
else:
|
||||
yield self.create_text_message(content.text)
|
||||
except json.JSONDecodeError:
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self.create_blob_message(
|
||||
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
in mcp tool invoke, if the parameter is empty, it will be set to None
|
||||
"""
|
||||
return {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
@@ -9,9 +9,10 @@ from configs import dify_config
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
|
||||
@@ -35,9 +35,10 @@ class ToolFileManager:
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
|
||||
@@ -4,7 +4,7 @@ import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@@ -13,9 +13,13 @@ from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
@@ -49,7 +53,7 @@ from core.tools.utils.configuration import (
|
||||
)
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -156,7 +160,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@@ -292,6 +296,8 @@ class ToolManager:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
@@ -302,6 +308,7 @@ class ToolManager:
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@@ -316,24 +323,9 @@ class ToolManager:
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if (
|
||||
parameter.type
|
||||
in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
and parameter.required
|
||||
):
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
runtime_parameters = cls._convert_tool_parameters_type(
|
||||
parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
|
||||
)
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
@@ -357,10 +349,12 @@ class ToolManager:
|
||||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
|
||||
tool_runtime = cls.get_tool_runtime(
|
||||
provider_type=workflow_tool.provider_type,
|
||||
provider_id=workflow_tool.provider_id,
|
||||
@@ -369,15 +363,11 @@ class ToolManager:
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
|
||||
for parameter in parameters:
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
runtime_parameters = cls._convert_tool_parameters_type(
|
||||
parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
|
||||
)
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
@@ -569,7 +559,7 @@ class ToolManager:
|
||||
|
||||
filters = []
|
||||
if not typ:
|
||||
filters.extend(["builtin", "api", "workflow"])
|
||||
filters.extend(["builtin", "api", "workflow", "mcp"])
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
@@ -663,6 +653,10 @@ class ToolManager:
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
@@ -690,14 +684,47 @@ class ToolManager:
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
auth_type = ApiProviderAuthType.NONE
|
||||
provider_auth_type = provider.credentials.get("auth_type")
|
||||
if provider_auth_type in ("api_key_header", "api_key"): # backward compatibility
|
||||
auth_type = ApiProviderAuthType.API_KEY_HEADER
|
||||
elif provider_auth_type == "api_key_query":
|
||||
auth_type = ApiProviderAuthType.API_KEY_QUERY
|
||||
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider,
|
||||
ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
auth_type,
|
||||
)
|
||||
controller.load_bundled_tools(provider.tools)
|
||||
|
||||
return controller, provider.credentials
|
||||
|
||||
@classmethod
|
||||
def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider_id: the id of the provider
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
MCPToolProvider.server_identifier == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
controller = MCPToolProviderController._from_db(provider)
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
|
||||
"""
|
||||
@@ -725,9 +752,16 @@ class ToolManager:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
auth_type = ApiProviderAuthType.NONE
|
||||
credentials_auth_type = credentials.get("auth_type")
|
||||
if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
|
||||
auth_type = ApiProviderAuthType.API_KEY_HEADER
|
||||
elif credentials_auth_type == "api_key_query":
|
||||
auth_type = ApiProviderAuthType.API_KEY_QUERY
|
||||
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider_obj,
|
||||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
auth_type,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
@@ -826,6 +860,22 @@ class ToolManager:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mcp_provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
return mcp_provider.provider_icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def get_tool_icon(
|
||||
cls,
|
||||
@@ -863,8 +913,61 @@ class ToolManager:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: Optional[VariablePool],
|
||||
tool_configurations: dict[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert tool parameters type
|
||||
"""
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.nodes.tool.exc import ToolParameterError
|
||||
|
||||
runtime_parameters = {}
|
||||
for parameter in parameters:
|
||||
if (
|
||||
parameter.type
|
||||
in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
and parameter.required
|
||||
and typ == "agent"
|
||||
):
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
if variable_pool:
|
||||
config = tool_configurations.get(parameter.name, {})
|
||||
if not (config and isinstance(config, dict) and config.get("value") is not None):
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
runtime_parameters[parameter.name] = parameter_value
|
||||
|
||||
else:
|
||||
value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
return runtime_parameters
|
||||
|
||||
|
||||
ToolManager.load_hardcoded_providers_cache()
|
||||
|
||||
@@ -72,21 +72,21 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
if use_cache:
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
@@ -105,7 +105,8 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
if use_cache:
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
@@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
path = "<root>"
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
|
||||
@@ -8,7 +8,12 @@ from flask_login import current_user
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.file import File
|
||||
|
||||
@@ -11,6 +11,11 @@ from .types import SegmentType
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
"""Segment is runtime type used during the execution of workflow.
|
||||
|
||||
Note: this class is abstract, you should use subclasses of this class instead.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
value_type: SegmentType
|
||||
@@ -73,7 +78,7 @@ class StringSegment(Segment):
|
||||
|
||||
|
||||
class FloatSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value_type: SegmentType = SegmentType.FLOAT
|
||||
value: float
|
||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||
# The following tests cannot pass.
|
||||
@@ -92,7 +97,7 @@ class FloatSegment(Segment):
|
||||
|
||||
|
||||
class IntegerSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value_type: SegmentType = SegmentType.INTEGER
|
||||
value: int
|
||||
|
||||
|
||||
@@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type
|
||||
elif isinstance(v, dict):
|
||||
value_type = v.get("value_type")
|
||||
if value_type is None:
|
||||
return None
|
||||
try:
|
||||
seg_type = SegmentType(value_type)
|
||||
except ValueError:
|
||||
return None
|
||||
return seg_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `Segment` for type hinting when serialization is not required.
|
||||
#
|
||||
# Note:
|
||||
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
# - `SegmentGroup`, which is not added to the variable pool.
|
||||
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
|
||||
SegmentUnion: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringSegment, Tag(SegmentType.STRING)]
|
||||
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
|
||||
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileSegment, Tag(SegmentType.FILE)]
|
||||
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
"""Strategy for validating array elements"""
|
||||
|
||||
# Skip element validation (only check array container)
|
||||
NONE = "none"
|
||||
|
||||
# Validate the first element (if array is non-empty)
|
||||
FIRST = "first"
|
||||
|
||||
# Validate all elements in the array.
|
||||
ALL = "all"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
FLOAT = "float"
|
||||
STRING = "string"
|
||||
OBJECT = "object"
|
||||
SECRET = "secret"
|
||||
@@ -19,16 +38,141 @@ class SegmentType(StrEnum):
|
||||
|
||||
GROUP = "group"
|
||||
|
||||
def is_array_type(self):
|
||||
def is_array_type(self) -> bool:
|
||||
return self in _ARRAY_TYPES
|
||||
|
||||
@classmethod
|
||||
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
|
||||
"""
|
||||
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
|
||||
|
||||
Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
|
||||
For example, this may occur if the input is a generic Python object of type `object`.
|
||||
"""
|
||||
|
||||
if isinstance(value, list):
|
||||
elem_types: set[SegmentType] = set()
|
||||
for i in value:
|
||||
segment_type = cls.infer_segment_type(i)
|
||||
if segment_type is None:
|
||||
return None
|
||||
|
||||
elem_types.add(segment_type)
|
||||
|
||||
if len(elem_types) != 1:
|
||||
if elem_types.issubset(_NUMERICAL_TYPES):
|
||||
return SegmentType.ARRAY_NUMBER
|
||||
return SegmentType.ARRAY_ANY
|
||||
elif all(i.is_array_type() for i in elem_types):
|
||||
return SegmentType.ARRAY_ANY
|
||||
match elem_types.pop():
|
||||
case SegmentType.STRING:
|
||||
return SegmentType.ARRAY_STRING
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return SegmentType.ARRAY_NUMBER
|
||||
case SegmentType.OBJECT:
|
||||
return SegmentType.ARRAY_OBJECT
|
||||
case SegmentType.FILE:
|
||||
return SegmentType.ARRAY_FILE
|
||||
case SegmentType.NONE:
|
||||
return SegmentType.ARRAY_ANY
|
||||
case _:
|
||||
# This should be unreachable.
|
||||
raise ValueError(f"not supported value {value}")
|
||||
if value is None:
|
||||
return SegmentType.NONE
|
||||
elif isinstance(value, int) and not isinstance(value, bool):
|
||||
return SegmentType.INTEGER
|
||||
elif isinstance(value, float):
|
||||
return SegmentType.FLOAT
|
||||
elif isinstance(value, str):
|
||||
return SegmentType.STRING
|
||||
elif isinstance(value, dict):
|
||||
return SegmentType.OBJECT
|
||||
elif isinstance(value, File):
|
||||
return SegmentType.FILE
|
||||
elif isinstance(value, str):
|
||||
return SegmentType.STRING
|
||||
else:
|
||||
return None
|
||||
|
||||
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
# Skip element validation if array is empty
|
||||
if len(value) == 0:
|
||||
return True
|
||||
if self == SegmentType.ARRAY_ANY:
|
||||
return True
|
||||
element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
|
||||
|
||||
if array_validation == ArrayValidation.NONE:
|
||||
return True
|
||||
elif array_validation == ArrayValidation.FIRST:
|
||||
return element_type.is_valid(value[0])
|
||||
else:
|
||||
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||
"""
|
||||
Check if a value matches the segment type.
|
||||
Users of `SegmentType` should call this method, instead of using
|
||||
`isinstance` manually.
|
||||
|
||||
Args:
|
||||
value: The value to validate
|
||||
array_validation: Validation strategy for array types (ignored for non-array types)
|
||||
|
||||
Returns:
|
||||
True if the value matches the type under the given validation strategy
|
||||
"""
|
||||
if self.is_array_type():
|
||||
return self._validate_array(value, array_validation)
|
||||
elif self == SegmentType.NUMBER:
|
||||
return isinstance(value, (int, float))
|
||||
elif self == SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
elif self == SegmentType.OBJECT:
|
||||
return isinstance(value, dict)
|
||||
elif self == SegmentType.SECRET:
|
||||
return isinstance(value, str)
|
||||
elif self == SegmentType.FILE:
|
||||
return isinstance(value, File)
|
||||
elif self == SegmentType.NONE:
|
||||
return value is None
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
def exposed_type(self) -> "SegmentType":
|
||||
"""Returns the type exposed to the frontend.
|
||||
|
||||
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
|
||||
"""
|
||||
if self in (SegmentType.INTEGER, SegmentType.FLOAT):
|
||||
return SegmentType.NUMBER
|
||||
return self
|
||||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have correpond element type.
|
||||
SegmentType.ARRAY_STRING: SegmentType.STRING,
|
||||
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_FILE: SegmentType.FILE,
|
||||
}
|
||||
|
||||
_ARRAY_TYPES = frozenset(
|
||||
[
|
||||
list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
|
||||
+ [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
_NUMERICAL_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import Annotated, TypeAlias, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Discriminator, Field, Tag
|
||||
|
||||
from core.helper import encrypter
|
||||
|
||||
@@ -20,6 +20,7 @@ from .segments import (
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from .types import SegmentType
|
||||
|
||||
@@ -27,6 +28,10 @@ from .types import SegmentType
|
||||
class Variable(Segment):
|
||||
"""
|
||||
A variable is a segment that has a name.
|
||||
|
||||
It is mainly used to store segments and their selector in VariablePool.
|
||||
|
||||
Note: this class is abstract, you should use subclasses of this class instead.
|
||||
"""
|
||||
|
||||
id: str = Field(
|
||||
@@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `Variable` for type hinting when serialization is not required.
|
||||
#
|
||||
# Note:
|
||||
# - All variants in `VariableUnion` must inherit from the `Variable` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
VariableUnion: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneVariable, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringVariable, Tag(SegmentType.STRING)]
|
||||
| Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
|
||||
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileVariable, Tag(SegmentType.FILE)]
|
||||
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@@ -232,14 +232,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
Publish loop started
|
||||
"""
|
||||
self.print_text("\n[LoopRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
|
||||
"""
|
||||
Publish loop next
|
||||
"""
|
||||
self.print_text("\n[LoopRunNextEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
self.print_text(f"Loop Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
|
||||
@@ -250,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
from typing import Annotated, Any, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, NoneSegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from factories import variable_factory
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
@@ -23,31 +24,31 @@ class VariablePool(BaseModel):
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||
description="Variables mapping",
|
||||
default=defaultdict(dict),
|
||||
)
|
||||
# TODO: This user inputs is not used for pool.
|
||||
|
||||
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
|
||||
user_inputs: Mapping[str, Any] = Field(
|
||||
description="User inputs",
|
||||
default_factory=dict,
|
||||
)
|
||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
||||
system_variables: SystemVariable = Field(
|
||||
description="System variables",
|
||||
default_factory=dict,
|
||||
)
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
environment_variables: Sequence[VariableUnion] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
conversation_variables: Sequence[Variable] = Field(
|
||||
conversation_variables: Sequence[VariableUnion] = Field(
|
||||
description="Conversation variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /) -> None:
|
||||
for key, value in self.system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
# Create a mapping from field names to SystemVariableKey enum values
|
||||
self._add_system_variables(self.system_variables)
|
||||
# Add environment variables to the variable pool
|
||||
for var in self.environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
@@ -83,8 +84,22 @@ class VariablePool(BaseModel):
|
||||
segment = variable_factory.build_segment(value)
|
||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]][hash_key] = variable
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
|
||||
|
||||
@classmethod
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
|
||||
return selector[0], hash(tuple(selector[1:]))
|
||||
|
||||
def _has(self, selector: Sequence[str]) -> bool:
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
if key not in self.variable_dictionary:
|
||||
return False
|
||||
if hash_key not in self.variable_dictionary[key]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
@@ -102,8 +117,8 @@ class VariablePool(BaseModel):
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
return None
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
value: Segment | None = self.variable_dictionary[key].get(hash_key)
|
||||
|
||||
if value is None:
|
||||
selector, attr = selector[:-1], selector[-1]
|
||||
@@ -136,8 +151,9 @@ class VariablePool(BaseModel):
|
||||
if len(selector) == 1:
|
||||
self.variable_dictionary[selector[0]] = {}
|
||||
return
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
self.variable_dictionary[key].pop(hash_key, None)
|
||||
|
||||
def convert_template(self, template: str, /):
|
||||
parts = VARIABLE_PATTERN.split(template)
|
||||
@@ -154,3 +170,20 @@ class VariablePool(BaseModel):
|
||||
if isinstance(segment, FileSegment):
|
||||
return segment
|
||||
return None
|
||||
|
||||
def _add_system_variables(self, system_variable: SystemVariable):
|
||||
sys_var_mapping = system_variable.to_dict()
|
||||
for key, value in sys_var_mapping.items():
|
||||
if value is None:
|
||||
continue
|
||||
selector = (SYSTEM_VARIABLE_NODE_ID, key)
|
||||
# If the system variable already exists, do not add it again.
|
||||
# This ensures that we can keep the id of the system variables intact.
|
||||
if self._has(selector):
|
||||
continue
|
||||
self.add(selector, value) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "VariablePool":
|
||||
"""Create an empty variable pool."""
|
||||
return cls(system_variables=SystemVariable.empty())
|
||||
|
||||
@@ -334,7 +334,7 @@ class Graph(BaseModel):
|
||||
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
@@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
|
||||
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
|
||||
#
|
||||
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
|
||||
# after a serialization and deserialization round trip.
|
||||
outputs: dict[str, Any] = {}
|
||||
"""outputs"""
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user