Compare commits

..

149 Commits

Author SHA1 Message Date
Stephen Zhou
d06ce2ef78 revert 2026-04-08 19:51:56 +08:00
Stephen Zhou
abcf4a5730 try disable csp 2026-04-08 19:06:15 +08:00
Stephen Zhou
5b3616aa33 Revert "try disable csp for test"
This reverts commit 19ab594c72.
2026-04-08 19:05:33 +08:00
Stephen Zhou
19ab594c72 try disable csp for test 2026-04-08 18:55:05 +08:00
Stephen Zhou
b64e930771 Merge branch 'main' into deploy/dev 2026-04-08 18:50:29 +08:00
Stephen Zhou
63bfba0bdb fix: update how ky handle error (#34735) 2026-04-08 10:38:33 +00:00
Coding On Star
9948a51b14 test: add unit tests for access control components to enhance coverage and reliability (#34722)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-08 08:50:57 +00:00
s-kawamura-upgrade
0e0bb3582f feat(web): add ALLOW_INLINE_STYLES env var to opt-in inline CSS in Markdown rendering (#34719)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-08 08:38:24 +00:00
Stephen Zhou
40bca2ad9c Merge branch 'main' into deploy/dev 2026-04-08 16:08:54 +08:00
Stephen Zhou
546062d2cd chore: remove raw vite deps (#34726) 2026-04-08 07:49:53 +00:00
Stephen Zhou
aad0b3c157 build: include vinext in docker build (#34535)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-04-08 07:26:39 +00:00
hj24
ef7dc9eabb Merge branch 'feat/new-biliing-quota' into deploy/dev 2026-04-08 15:02:50 +08:00
hj24
ae01a5d137 fix: unit test mock 2026-04-08 14:42:52 +08:00
hj24
ad6670ebcc fix: correct quota info response 2026-04-08 14:23:57 +08:00
hj24
8ca0917044 Merge branch 'main' into feat/new-biliing-quota 2026-04-08 13:39:24 +08:00
corevibe555
4d4265f531 refactor(api): deduplicate Pydantic models across fields and controllers (#34718) 2026-04-08 05:20:00 +00:00
Will
e138523123 fix: legacy model_type deserialization regression (#34717) 2026-04-08 05:08:12 +00:00
carlos4s
a65e1f71b4 refactor: use sessionmaker in small services 2 (#34696)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-08 05:06:50 +00:00
yyh
909c062ee1 fix(web): avoid prehydration script in slider (#34676) 2026-04-08 04:03:19 +00:00
hj24
f5322e45fc refactor: enhance billing info response handling (#34340)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-08 03:49:35 +00:00
Stephen Zhou
017f09f1e9 ci: update web changes scope (#34713) 2026-04-08 03:24:41 +00:00
corevibe555
0ba66ab155 refactor(api): deduplicate shared controller request schemas into controller_schemas.py (#34700)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-08 03:10:04 +00:00
corevibe555
5cd267d755 refactor(api): deduplicate RAG index entities and consolidate import paths (#34690) 2026-04-08 02:49:40 +00:00
Stephen Zhou
d30946dabf chore: update deps (#34704) 2026-04-08 02:45:30 +00:00
wangxiaolei
b0e524213e fix: backendModelConfig.chat_prompt_config.prompt is undefined (#34709) 2026-04-08 02:29:18 +00:00
corevibe555
b1adb5652e refactor(api): deduplicate I18nObject in datasource entities (#34701) 2026-04-08 01:36:56 +00:00
corevibe555
c825d5dcf6 refactor(api): tighten types for Tenant.custom_config_dict and MCPToolProvider.headers (#34698) 2026-04-08 01:36:42 +00:00
Renzo
2127d5850f refactor: replace untyped dicts with TypedDict in VDB config classes (#34697) 2026-04-08 00:57:11 +00:00
carlos4s
ae9fcc2969 refactor: use sessionmaker in controllers, events, models, and tasks 1 (#34693) 2026-04-07 23:47:20 +00:00
corevibe555
624db69f12 refactor(api): remove duplicated RAG entities from services layer (#34689) 2026-04-07 23:36:59 +00:00
corevibe555
80a7843f45 refactor(api): migrate consumers to shared RAG domain entities from core/rag/entities/ (#34692) 2026-04-07 23:22:56 +00:00
Renzo
cb55176612 refactor: migrate session.query to select API in small task files batch (#34684)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-07 22:58:23 +00:00
Statxc
5aa2524d33 refactor(api): type I18nObject.to_dict with I18nObjectDict TypedDict (#34680) 2026-04-07 22:57:32 +00:00
Pulakesh
2575a3a3ab refactor(api): clean up AssistantPromptMessage typing in CotChatAgentRunner (#34681) 2026-04-07 22:53:14 +00:00
corevibe555
f8f7b0ec1a refactor(api): deduplicate shared auth request payloads into auth_entities.py (#34694) 2026-04-07 22:51:11 +00:00
corevibe555
d2ee486900 refactor(api): extract shared RAG domain entities into core/rag/entity (#34685) 2026-04-07 22:43:37 +00:00
Statxc
c44ddd9831 refactor(api): type Chroma and AnalyticDB config params dicts with TypedDicts (#34678) 2026-04-07 13:27:12 +00:00
Statxc
e645cbd8f8 refactor(api): type VDB config params dicts with TypedDicts (#34677) 2026-04-07 13:23:42 +00:00
YBoy
485fc2c416 refactor(api): type Tenant custom config with TypedDict and tighten MCP headers type (#34670) 2026-04-07 13:18:19 +00:00
YBoy
f09be969bb refactor(api): type single-node graph structure with TypedDicts in workflow_entry (#34671) 2026-04-07 13:18:00 +00:00
Statxc
597a0b4d9f refactor(api): type indexing result with IndexingResultDict TypedDict (#34672) 2026-04-07 13:17:39 +00:00
Statxc
779cce3c61 refactor(api): type gen_index_struct_dict with VectorIndexStructDict TypedDict (#34675) 2026-04-07 13:17:20 +00:00
Statxc
b5d9a71cf9 refactor(api): type VDB to_index_struct with VectorIndexStructDict TypedDict (#34674) 2026-04-07 13:17:04 +00:00
corevibe555
c2af415450 refactor(api): Extract shared ResponseModel (#34633) 2026-04-07 13:05:38 +00:00
Dream
89ce61cfea refactor(api): replace json.loads with Pydantic validation in security and tools layers (#34380) 2026-04-07 12:11:51 +00:00
yyh
05c5327f47 chore: remove unused pnpm overrides (#34658) 2026-04-07 09:36:49 +00:00
yyh
3891c0a255 fix(workflow): correct env variable picker validation (#34666) 2026-04-07 09:34:25 +00:00
非法操作
63b1d0c1ea fix: var input label missing icon (#34569) 2026-04-07 09:33:13 +00:00
Pulakesh
75ed38fb3d fix(#34636): replace SimpleNamespace with MagicMock(spec=App) in test_app_dsl_service (#34659) 2026-04-07 07:25:46 +00:00
Statxc
63db9a7a2f refactor(api): type load balancing config dicts with TypedDict (#34639) 2026-04-07 05:58:10 +00:00
Statxc
19c80f0f0e refactor(api): type error stream response with TypedDict (#34641) 2026-04-07 05:57:42 +00:00
YBoy
c5a0bde3ec refactor(api): type aliyun trace utils with TypedDict and tighten return types (#34642) 2026-04-07 05:57:22 +00:00
YBoy
1261e5e5e8 refactor(api): type webhook validation result and workflow inputs with TypedDict (#34645) 2026-04-07 05:57:02 +00:00
Renzo
e2ecd68556 refactor: migrate session.query to select API in rag pipeline task files (#34648) 2026-04-07 05:56:19 +00:00
Pulakesh
bceb0eee9b refactor(api): migrate dict returns to TypedDicts in billing service (#34649) 2026-04-07 05:56:02 +00:00
Renzo
173e818a62 refactor: migrate session.query to select API in summary and remove document tasks (#34650) 2026-04-07 05:55:31 +00:00
YBoy
84d8940dbf refactor(api): type app parameter feature toggles with FeatureToggleD… (#34651) 2026-04-07 05:53:50 +00:00
Renzo
3e995e6a6d refactor: migrate session.query to select API in document task files (#34646) 2026-04-07 05:53:21 +00:00
yyh
459c36f21b fix: improve app delete alert dialog UX (#34644)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-07 05:03:39 +00:00
Renzo
72adb5468c refactor: migrate session.query to select API in retrieval_service (#34638) 2026-04-07 04:46:30 +00:00
Renzo
1194957fde refactor: migrate session.query to select API in end_user_service and small tasks (#34620) 2026-04-07 04:25:55 +00:00
Renzo
68bd29eda2 refactor: migrate session.query to select API in sync task and services (#34619) 2026-04-07 04:23:14 +00:00
YBoy
f67a811f7f refactor: replace dict params with BaseModel payloads in TagService (#34422)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-07 04:20:02 +00:00
yyh
b9c122e7f4 fix: simplify pre-commit hook flow (#34637) 2026-04-07 04:19:31 +00:00
aliworksx08
396b39dff9 refactor: migrate session.query to select API in console controllers (#34607) 2026-04-07 04:19:30 +00:00
Renzo
ac8bd12609 refactor: migrate session.query to select API in small task files (#34617) 2026-04-07 04:13:22 +00:00
Renzo
b55bef4438 refactor: migrate session.query to select API in core misc modules (#34608) 2026-04-07 04:08:34 +00:00
非法操作
2f9667de76 fix: web app user avatar display incorrect black (#34624) 2026-04-07 03:23:56 +00:00
Statxc
a7b6307d32 refactor(api): type dataset service dicts with TypedDict (#34625) 2026-04-07 02:10:52 +00:00
Statxc
2883ad6764 refactor(api): type plugin migration results with TypedDict (#34627) 2026-04-07 02:10:23 +00:00
Pulakesh
0feff5b048 refactor(api): enforce strict typing on retrieval_model to resolve FIXME (#34614) 2026-04-07 01:10:53 +00:00
Statxc
0bce6b35b4 refactor(api): type LLM generator results with TypedDict (#34621) 2026-04-07 01:06:08 +00:00
YBoy
89e23456f0 refactor(api): type invitation detail with InvitationDetailDict TypedDict (#34613)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-07 01:03:31 +00:00
Jake Armstrong
a39173c227 refactor(api): type notification response with NotificationResponseDict TypedDict (#34616) 2026-04-07 01:03:18 +00:00
YBoy
12e93d374f refactor(api): type MCP tool schema and arguments with TypedDict (#34612) 2026-04-07 01:02:06 +00:00
YBoy
922f9242e4 refactor(api): type crawl status dicts with CrawlStatusDict TypedDict (#34611) 2026-04-07 01:01:04 +00:00
YBoy
7fc0a791a2 refactor(api): type document summary status detail with TypedDict (#34610) 2026-04-07 01:00:39 +00:00
YBoy
8d37116fec refactor(api): type storage statistics with StorageStatisticsDict TypedDict (#34609)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-07 00:59:32 +00:00
dependabot[bot]
4b500f988d chore(deps-dev): bump the dev group across 1 directory with 20 updates (#34601)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 13:24:31 +00:00
YBoy
5ad906ea6a refactor(api): type workflow run related counts with RelatedCountsDict TypedDict (#34530) 2026-04-06 13:17:01 +00:00
dependabot[bot]
5b862a43e0 chore(deps-dev): bump the dev group in /api with 6 updates (#34579)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-06 11:49:54 +00:00
YBoy
1e5cd69205 refactor(api): type archive manifest with ArchiveManifestDict TypedDict (#34594) 2026-04-06 11:35:11 +00:00
Jake Armstrong
9081c46565 refactor(api): type upload file serialization with UploadFileDict TypedDict (#34589) 2026-04-06 11:34:52 +00:00
dependabot[bot]
40b252be8c chore(deps): bump google-auth-httplib2 from 0.3.0 to 0.3.1 in /api in the google group (#34575)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-06 11:32:30 +00:00
dependabot[bot]
ba1357038a chore(deps): update flask-compress requirement from <1.24,>=1.17 to >=1.17,<1.25 in /api in the flask group (#34573)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-06 11:32:19 +00:00
dependabot[bot]
46d1f4c338 chore(deps-dev): bump the vdb group in /api with 7 updates (#34586)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:31:36 +00:00
YBoy
9c880dd650 refactor(api): type orphaned draft variable stats with TypedDict (#34590) 2026-04-06 11:30:53 +00:00
YBoy
01ba0e050f refactor(api): reuse IdentityDict TypedDict in logging filters (#34593) 2026-04-06 11:30:21 +00:00
dependabot[bot]
ccc4aae94e chore(deps): bump the llm group in /api with 3 updates (#34583)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:30:02 +00:00
dependabot[bot]
01242e13d7 chore(deps): bump sqlalchemy from 2.0.48 to 2.0.49 in /api in the database group (#34584)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:29:50 +00:00
dependabot[bot]
938ee27e42 chore(deps): bump the github-actions-dependencies group with 4 updates (#34582)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:29:07 +00:00
dependabot[bot]
a101f72153 chore(deps): bump the google group in /api with 4 updates (#34581)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:29:00 +00:00
dependabot[bot]
40642433d8 chore(deps): bump flask-compress from 1.23 to 1.24 in /api in the flask group (#34580)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:28:25 +00:00
dependabot[bot]
8979181d5e chore(deps): bump boto3 from 1.42.78 to 1.42.83 in /api in the storage group (#34578)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-06 11:27:58 +00:00
dependabot[bot]
c17c6b5c35 chore(deps): bump the storage group in /api with 2 updates (#34585)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-06 11:27:26 +00:00
kurokobo
e83a4090ac fix: lighten the health checks for the Worker and Worker Beat services, and disable them by default (#34572) 2026-04-06 02:26:26 +00:00
YBoy
b71b9f80b9 refactor(api): type workflow run delete/count results with RunsWithRelatedCountsDict TypedDict (#34531)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-05 16:11:41 +00:00
agenthaulk
ee87289917 refactor: convert AppMode if/elif to match/case in app_generate_service (#30001) (#34563)
Co-authored-by: agenthaulk <agenthaulk@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-05 10:23:51 +00:00
agenthaulk
5ad8c3e249 refactor: convert AppMode if/elif to match/case in service files (#30001) (#34562)
Co-authored-by: agenthaulk <agenthaulk@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-05 10:22:11 +00:00
agenthaulk
8b992513b8 refactor: convert ProviderQuotaType if/elif to match/case (#30001) (#34561)
Co-authored-by: agenthaulk <agenthaulk@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-05 10:20:18 +00:00
Renzo
eca0cdc7a9 refactor: select in dataset_service (SegmentService and remaining cla… (#34547)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-05 00:13:06 +00:00
Renzo
779e6b8e0b refactor: select in datasource_provider_service (#34548)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-05 00:12:15 +00:00
Renzo
c2428361c4 refactor: select in dataset_service (DocumentService class) (#34528)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 22:52:01 +00:00
Renzo
68e4d13f36 refactor: select in annotation_service (#34503)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 22:47:22 +00:00
Stephen Zhou
cb9f4bb100 build: include packages in docker build (#34532) 2026-04-03 13:40:16 +00:00
YBoy
8a398f3105 refactor(api): type messages cleanup stats with MessagesCleanStatsDict TypedDict (#34527) 2026-04-03 12:29:41 +00:00
YBoy
0f051d5886 refactor(api): type celery sqlcommenter tags with CelerySqlcommenterTagsDict TypedDict (#34526) 2026-04-03 12:06:15 +00:00
Renzo
e85d9a0d72 refactor: select in dataset_service (DatasetService class) (#34525)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 12:01:31 +00:00
Renzo
06dde4f503 refactor: select in account_service (TenantService class) (#34499)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-03 11:03:45 +00:00
Coding On Star
83d4176785 test: add unit tests for app store and annotation components, enhancing coverage for state management and UI interactions (#34510)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 09:09:59 +00:00
yyh
c94951b2f8 refactor(web): migrate notion page selectors to tanstack virtual (#34508)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 07:03:12 +00:00
Matt Van Horn
a9cf8f6c5d refactor(web): replace react-syntax-highlighter with shiki (#33473)
Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 06:40:26 +00:00
YBoy
64ddec0d67 refactor(api): type annotation service dicts with TypedDict (#34482)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-03 06:25:52 +00:00
Renzo
da3b0caf5e refactor: select in account_service (RegisterService class) (#34500)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 06:21:26 +00:00
Stephen Zhou
4fedd43af5 chore: update code-inspector-plugin to 1.5.1 (#34506) 2026-04-03 05:34:03 +00:00
yyh
a263f28e19 fix(web): restore ui select public exports (#34501) 2026-04-03 04:42:02 +00:00
Stephen Zhou
d53862f135 chore: override lodash (#34502) 2026-04-03 04:40:46 +00:00
hj24
b2861e019b fix: merge error 2026-04-02 18:16:31 +08:00
Joel
cad9936c0a Merge branch 'fix/ps-not-send' into deploy/dev 2026-04-02 17:55:04 +08:00
hj24
8c0b596ced Merge branch 'chore-debug-partnerstack' into deploy/dev 2026-04-02 17:54:06 +08:00
Joel
65e434cf06 chore: add debug 2026-04-02 17:53:52 +08:00
hj24
12a0f85b72 feat: clear api 2026-04-02 17:52:55 +08:00
hj24
1fdb653875 feat: debug partnerstack 2026-04-02 17:18:25 +08:00
hj24
4ba8c71962 feat: debug partnerstack 2026-04-02 17:17:40 +08:00
Joel
1f1c74099f Merge branch 'fix/ps-not-send' into deploy/dev 2026-04-02 12:53:28 +08:00
Joel
359007848d chore: remove save binded cookie 2026-04-02 12:53:07 +08:00
Joel
43fedac47b Merge branch 'fix/ps-not-send' into deploy/dev 2026-04-02 11:23:20 +08:00
Joel
20ddc9c48a fix: url query change record cookie 2026-04-02 11:22:46 +08:00
hj24
a91c1a2af0 Merge branch 'refactor-enhance-billing-info-guard' into deploy/dev 2026-04-02 11:02:00 +08:00
Yansong Zhang
b3870524d4 fix usage get 2026-04-02 09:52:52 +08:00
hj24
919c080452 chore: update comments 2026-04-01 10:35:34 +08:00
hj24
4653ed7ead refactor: enhance billing info response handling 2026-03-31 18:23:32 +08:00
Yansong Zhang
c543188434 fix linter 2026-03-31 15:22:51 +08:00
Yansong Zhang
f319a9e42f fix test case 2026-03-31 15:22:43 +08:00
Yansong Zhang
58241a89a5 fix linter 2026-03-31 14:59:54 +08:00
Yansong Zhang
422bf3506e rebuild quota service 2026-03-31 14:59:45 +08:00
Yansong Zhang
6e745f9e9b fix linter 2026-03-31 09:49:24 +08:00
Yansong Zhang
4e50d55339 fix comment 2026-03-31 09:49:09 +08:00
autofix-ci[bot]
b95cdabe26 [autofix.ci] apply automated fixes 2026-03-30 08:45:37 +00:00
Yansong Zhang
daa47c25bb Merge branch 'feat/new-biliing-quota' of github.com:langgenius/dify into feat/new-biliing-quota 2026-03-30 16:43:13 +08:00
Yansong Zhang
f1bcd6d715 add test case for quota and billing service 2026-03-30 16:41:56 +08:00
hj24
8643ff43f5 Merge branch 'main' into feat/new-biliing-quota 2026-03-30 15:57:49 +08:00
Yansong Zhang
c5f30a47f0 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-30 15:26:38 +08:00
Yansong Zhang
37d438fa19 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-27 16:26:09 +08:00
Yansong Zhang
9503803997 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-23 09:27:39 +08:00
Yansong Zhang
d6476f5434 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-20 15:17:27 +08:00
Yansong Zhang
80b4633e8f fix style check and test 2026-03-20 14:58:31 +08:00
autofix-ci[bot]
3888969af3 [autofix.ci] apply automated fixes 2026-03-20 05:45:30 +00:00
Yansong Zhang
658ac15589 use new quota system 2026-03-20 13:29:22 +08:00
660 changed files with 32704 additions and 11481 deletions

9
.github/labeler.yml vendored
View File

@@ -1,3 +1,10 @@
web:
- changed-files:
- any-glob-to-any-file: 'web/**'
- any-glob-to-any-file:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'

View File

@@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods

View File

@@ -39,9 +39,11 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'

View File

@@ -65,7 +65,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
@@ -130,7 +130,7 @@ jobs:
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}

View File

@@ -8,9 +8,11 @@ on:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:

View File

@@ -65,9 +65,11 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@@ -77,9 +79,11 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'

View File

@@ -77,9 +77,11 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@@ -149,7 +151,7 @@ jobs:
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@@ -9,6 +9,7 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@@ -240,7 +240,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -95,31 +95,5 @@ if $web_modified; then
exit 1
fi
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
if [ -n "$modified_files" ]; then
for file in $modified_files; do
test_file="${file%.*}.spec.ts"
echo "Checking for test file: $test_file"
# check if the test file exists
if [ -f "../$test_file" ]; then
echo "Detected changes in $file, running corresponding unit tests..."
pnpm run test "../$test_file"
if [ $? -ne 0 ]; then
echo "Unit tests failed. Please fix the errors before committing."
exit 1
fi
echo "Unit tests for $file passed."
else
echo "Warning: $file does not have a corresponding test file."
fi
done
echo "All unit tests for modified web/utils files have passed."
fi
cd ../
fi

18
api/celery_healthcheck.py Normal file
View File

@@ -0,0 +1,18 @@
# This module provides a lightweight Celery instance for use in Docker health checks.
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
# Using this module keeps the health check fast and low-cost.
from celery import Celery
from configs import dify_config
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
broker_transport_options = get_celery_broker_transport_options()
if broker_transport_options:
celery.conf.update(broker_transport_options=broker_transport_options)
ssl_options = get_celery_ssl_options()
if ssl_options:
celery.conf.update(broker_use_ssl=ssl_options)

View File

@@ -1,7 +1,7 @@
import datetime
import logging
import time
from typing import Any
from typing import TypedDict
import click
import sqlalchemy as sa
@@ -503,7 +503,19 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
return [row[0] for row in result]
def _count_orphaned_draft_variables() -> dict[str, Any]:
class _AppOrphanCounts(TypedDict):
variables: int
files: int
class OrphanedDraftVariableStatsDict(TypedDict):
total_orphaned_variables: int
total_orphaned_files: int
orphaned_app_count: int
orphaned_by_app: dict[str, _AppOrphanCounts]
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
"""
Count orphaned draft variables by app, including associated file counts.
@@ -526,7 +538,7 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
with db.engine.connect() as conn:
result = conn.execute(sa.text(variables_query))
orphaned_by_app = {}
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
total_files = 0
for row in result:

View File

@@ -0,0 +1,63 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
from libs.helper import UUIDStrOrEmpty
# --- Conversation schemas ---
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
# --- Message schemas ---
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
# --- Saved message schemas ---
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
# --- Workflow schemas ---
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None

View File

@@ -2,6 +2,7 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import cast
from flask import request
from flask_restx import Resource
@@ -17,7 +18,7 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
from services.billing_service import BillingService, LangContentDict
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -328,7 +329,7 @@ class UpsertNotificationApi(Resource):
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[c.model_dump() for c in payload.contents],
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,

View File

@@ -7,7 +7,7 @@ from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest
@@ -26,9 +26,11 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
@@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
NotionIcon,
NotionInfo,
NotionPage,
PreProcessingRule,
RerankingModel,
Rule,
Segmentation,
WebsiteInfo,
WeightKeywordSetting,
WeightModel,
@@ -155,16 +154,6 @@ class AppTracePayload(BaseModel):
type JSONValue = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())

View File

@@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session)
# Import app
account = current_user

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@@ -59,10 +60,8 @@ class ChatMessagesQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod

View File

@@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
.limit(1)
)
if not webhook_trigger:

View File

@@ -3,7 +3,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from libs.password import hash_password
from services.account_service import AccountService, TenantService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")

View File

@@ -1,5 +1,3 @@
from typing import Any
import flask_login
from flask import make_response, request
from flask_restx import Resource
@@ -42,8 +40,9 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@@ -51,9 +50,7 @@ from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
class LoginPayload(LoginPayloadBase):
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
@@ -101,7 +98,7 @@ class LoginApi(Resource):
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
invitation_data: InvitationDetailDict | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:

View File

@@ -1,4 +1,6 @@
import base64
import json
from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
@@ -9,6 +11,7 @@ from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -84,3 +87,39 @@ class PartnerTenants(Resource):
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
_DEBUG_KEY = "billing:debug"
_DEBUG_TTL = timedelta(days=7)
class DebugDataPayload(BaseModel):
type: str = Field(..., min_length=1, description="Data type key")
data: str = Field(..., min_length=1, description="Data value to append")
@console_ns.route("/billing/debug/data")
class DebugData(Resource):
def post(self):
body = DebugDataPayload.model_validate(request.get_json(force=True))
item = json.dumps({
"type": body.type,
"data": body.data,
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
})
redis_client.lpush(_DEBUG_KEY, item)
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
return {"result": "ok"}, 201
def get(self):
recent = request.args.get("recent", 10, type=int)
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
return {
"data": [
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
]
}
def delete(self):
redis_client.delete(_DEBUG_KEY)
return {"result": "ok"}

View File

@@ -3,6 +3,7 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@@ -86,8 +87,8 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
template = session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
)
if not template:
raise ValueError("Customized pipeline template not found.")

View File

@@ -2,10 +2,10 @@ import logging
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
@@ -32,14 +32,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)

View File

@@ -1,10 +1,11 @@
from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
@@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@@ -3,9 +3,10 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
@@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
@@ -44,17 +44,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]

View File

@@ -1,28 +1,18 @@
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@@ -1,11 +1,10 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
@@ -34,12 +33,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)

View File

@@ -1,3 +1,5 @@
from typing import TypedDict
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@@ -11,6 +13,21 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationResponseDict(TypedDict):
should_show: bool
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
@@ -45,28 +62,30 @@ class NotificationApi(Resource):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
response: NotificationResponseDict
if not result.get("shouldShow"):
return {"should_show": False, "notifications": []}, 200
response = {"should_show": False, "notifications": []}
return response, 200
lang = current_user.interface_language or _FALLBACK_LANG
notifications = []
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
notifications.append(
{
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
notifications.append(item)
return {"should_show": bool(notifications), "notifications": notifications}, 200
response = {"should_show": bool(notifications), "notifications": notifications}
return response, 200
@console_ns.route("/notification/dismiss")

View File

@@ -9,7 +9,14 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
dataset_tag_fields = {
"id": fields.String,
@@ -25,19 +32,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
type: TagType = Field(description="Tag type")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
type: TagType = Field(description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel):
@@ -82,7 +89,7 @@ class TagListApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@@ -103,7 +110,7 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -136,7 +143,9 @@ class TagBindingCreateApi(Resource):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
@@ -154,6 +163,8 @@ class TagBindingDeleteApi(Resource):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from functools import wraps
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
@@ -21,12 +22,12 @@ def plugin_permission_required(
tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session:
permission = (
session.query(TenantPluginPermission)
permission = session.scalar(
select(TenantPluginPermission)
.where(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not permission:

View File

@@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args.remove_webapp_brand,
custom_config_dict: TenantCustomConfigDict = {
"remove_webapp_brand": args.remove_webapp_brand
if args.remove_webapp_brand is not None
else tenant.custom_config_dict.get("remove_webapp_brand", False),
"replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),

View File

@@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@@ -4,6 +4,7 @@ from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.schema import register_schema_model
@@ -80,11 +81,11 @@ class MCPAppApi(Resource):
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.query(App).where(App.id == mcp_server.app_id).first()
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
if not app:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
@@ -190,12 +191,12 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return (
session.query(EndUser)
return session.scalar(
select(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
.limit(1)
)
def _create_end_user(

View File

@@ -2,11 +2,12 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
@@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
)
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")

View File

@@ -1,5 +1,4 @@
import logging
from typing import Literal
from flask import request
from flask_restx import Resource
@@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
@@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
@@ -27,17 +26,6 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Literal
from typing import Literal
from dateutil.parser import isoparse
from flask import request
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class WorkflowRunPayload(WorkflowRunPayloadBase):
response_mode: Literal["blocking", "streaming"] | None = None

View File

@@ -22,10 +22,17 @@ from fields.tag_fields import DataSetTag
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id
tag = TagService.update_tags(params, tag_id)
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource):
raise Forbidden()
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204
@@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource):
raise Forbidden()
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204

View File

@@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
@@ -40,11 +41,8 @@ from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
PreProcessingRule,
ProcessRule,
RetrievalModel,
Rule,
Segmentation,
)
from services.file_service import FileService
from services.summary_index_service import SummaryIndexService

View File

@@ -4,13 +4,23 @@ Serialization helpers for Service API knowledge pipeline endpoints.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from models.model import UploadFile
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
class UploadFileDict(TypedDict):
id: str
name: str
size: int
extension: str
mime_type: str | None
created_by: str
created_at: str | None
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
return {
"id": upload_file.id,
"name": upload_file.name,

View File

@@ -3,10 +3,11 @@ import logging
from flask import request
from flask_restx import fields, marshal_with
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, field_validator
from pydantic import field_validator
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@@ -34,12 +35,7 @@ from services.errors.audio import (
from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
class TextToAudioPayload(TextToAudioPayloadBase):
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:

View File

@@ -1,10 +1,11 @@
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
@@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)

View File

@@ -3,7 +3,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from libs.helper import extract_remote_ip
from libs.password import hash_password
from models.account import Account
from services.account_service import AccountService
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)

View File

@@ -29,13 +29,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
class LoginPayload(BaseModel):
email: EmailStr
password: str
class LoginPayload(LoginPayloadBase):
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:

View File

@@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",

View File

@@ -1,27 +1,17 @@
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@@ -1,11 +1,10 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
logger = logging.getLogger(__name__)
register_schema_models(web_ns, WorkflowRunPayload)

View File

@@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
content = ""
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"
content += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n"
content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message]
assistant_messages = [AssistantPromptMessage(content=content)]
# query messages
query_messages = self._organize_user_query(self._query, [])

View File

@@ -5,6 +5,10 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class FeatureToggleDict(TypedDict):
enabled: bool
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
@@ -16,12 +20,12 @@ class SystemParametersDict(TypedDict):
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
suggested_questions_after_answer: FeatureToggleDict
speech_to_text: FeatureToggleDict
text_to_speech: FeatureToggleDict
retriever_resource: FeatureToggleDict
annotation_reply: FeatureToggleDict
more_like_this: FeatureToggleDict
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]

View File

@@ -1,4 +1,3 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.entities import MetadataFilteringCondition
from models.model import AppMode
@@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict)
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class ModelConfig(BaseModel):
provider: str
name: str
@@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
completion_params: dict[str, Any] = Field(default_factory=dict)
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.

View File

@@ -107,13 +107,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception):
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses = {
error_responses: dict[type[Exception], dict[str, Any]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@@ -127,7 +127,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data = None
data: dict[str, Any] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import (
build_bootstrap_variables,

View File

@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
class QueueEvent(StrEnum):

View File

@@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
class AnnotationReplyAccount(BaseModel):

View File

@@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
agent_thought: MessageAgentThought | None = session.scalar(
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
)
if agent_thought:

View File

@@ -6,7 +6,7 @@ from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db

View File

@@ -345,8 +345,8 @@ class DatasourceManager:
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
upload_file = session.scalar(
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

View File

@@ -1,22 +1,3 @@
from pydantic import BaseModel, Field, model_validator
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
__all__ = ["I18nObject", "I18nObjectDict"]

View File

@@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,

View File

@@ -1 +1,8 @@
from core.entities.plugin_credential_type import PluginCredentialType
DEFAULT_PLUGIN_ID = "langgenius"
__all__ = [
"DEFAULT_PLUGIN_ID",
"PluginCredentialType",
]

View File

@@ -0,0 +1,9 @@
import enum
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@@ -22,6 +22,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.entities import PluginCredentialType
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
@@ -46,7 +47,6 @@ from models.provider import (
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,7 @@
Credential utility functions for checking credential existence and policy compliance.
"""
from services.enterprise.plugin_manager_service import PluginCredentialType
from core.entities import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, cast
from typing import Protocol, TypedDict, cast
import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey
@@ -49,6 +49,17 @@ class WorkflowServiceInterface(Protocol):
pass
class CodeGenerateResultDict(TypedDict):
code: str
language: str
error: str
class StructuredOutputResultDict(TypedDict):
output: str
error: str
class LLMGenerator:
@classmethod
def generate_conversation_name(
@@ -293,7 +304,7 @@ class LLMGenerator:
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
):
) -> CodeGenerateResultDict:
if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
@@ -362,7 +373,9 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
def generate_structured_output(
cls, tenant_id: str, args: RuleStructuredOutputPayload
) -> StructuredOutputResultDict:
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@@ -454,7 +467,7 @@ class LLMGenerator:
):
session = db.session()
app: App | None = session.query(App).where(App.id == flow_id).first()
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
if not app:
raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@@ -6,6 +6,7 @@ import logging
import flask
from core.logging.context import get_request_id, get_trace_id
from core.logging.structured_formatter import IdentityDict
class TraceContextFilter(logging.Filter):
@@ -60,7 +61,7 @@ class IdentityContextFilter(logging.Filter):
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
def _extract_identity(self) -> IdentityDict:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
@@ -77,7 +78,7 @@ class IdentityContextFilter(logging.Filter):
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
identity: IdentityDict = {}
if isinstance(user, Account):
if user.current_tenant_id:

View File

@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, NotRequired, TypedDict, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@@ -15,6 +15,17 @@ from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class ToolParameterSchemaDict(TypedDict):
type: str
properties: dict[str, Any]
required: list[str]
class ToolArgumentsDict(TypedDict):
query: NotRequired[str]
inputs: dict[str, Any]
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
@@ -119,7 +130,7 @@ def handle_list_tools(
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=parameter_schema,
inputSchema=cast(dict[str, Any], parameter_schema),
)
],
)
@@ -154,7 +165,7 @@ def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> dict[str, Any]:
) -> ToolParameterSchemaDict:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
@@ -174,7 +185,7 @@ def build_parameter_schema(
}
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}

View File

@@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from configs import dify_config
from core.entities import PluginCredentialType
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
@@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -56,10 +56,22 @@ def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
return links
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
documents_data = []
class RetrievalDocumentMetadataDict(TypedDict):
dataset_id: Any
doc_id: Any
document_id: Any
class RetrievalDocumentDict(TypedDict):
content: str
metadata: RetrievalDocumentMetadataDict
score: Any
def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]:
documents_data: list[RetrievalDocumentDict] = []
for document in documents:
document_data = {
document_data: RetrievalDocumentDict = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
@@ -83,7 +95,7 @@ def create_common_span_attributes(
framework: str = DEFAULT_FRAMEWORK_NAME,
inputs: str = "",
outputs: str = "",
) -> dict[str, Any]:
) -> dict[str, str]:
return {
GEN_AI_SESSION_ID: session_id,
GEN_AI_USER_ID: user_id,

View File

@@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@@ -0,0 +1,5 @@
from core.plugin.entities.oauth import OAuthSchema
__all__ = [
"OAuthSchema",
]

View File

@@ -1,5 +1,3 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
@@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
OAuth schema
"""
client_schema: Sequence[ProviderConfig] = Field(
client_schema: list[ProviderConfig] = Field(
default_factory=list,
description="client schema like client_id, client_secret, etc.",
)
credentials_schema: Sequence[ProviderConfig] = Field(
credentials_schema: list[ProviderConfig] = Field(
default_factory=list,
description="credentials schema like access_token, refresh_token, etc.",
)

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -15,6 +14,7 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -58,6 +58,8 @@ from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class ProviderManager:
"""
@@ -875,8 +877,8 @@ class ProviderManager:
return {"openai_api_key": encrypted_config}
try:
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
credentials = _credentials_adapter.validate_json(encrypted_config)
except (ValueError, JSONDecodeError):
return {}
# Decrypt secret variables
@@ -1015,7 +1017,7 @@ class ProviderManager:
if not cached_provider_credentials:
provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config:
provider_credentials = json.loads(provider_records[0].encrypted_config)
provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
@@ -1162,8 +1164,10 @@ class ProviderManager:
if not cached_provider_model_credentials:
try:
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
except JSONDecodeError:
provider_model_credentials = _credentials_adapter.validate_json(
load_balancing_model_config.encrypted_config
)
except (ValueError, JSONDecodeError):
continue
# Get decoding rsa key and cipher for decrypting credentials
@@ -1176,7 +1180,7 @@ class ProviderManager:
if variable in provider_model_credentials:
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
provider_model_credentials.get(variable) or "",
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)

View File

@@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.entities import MetadataFilteringCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
@@ -182,7 +182,9 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
if metadata_filtering_conditions
else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
@@ -240,7 +242,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:
return session.query(Dataset).where(Dataset.id == dataset_id).first()
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
@classmethod
def keyword_search(
@@ -573,15 +575,13 @@ class RetrievalService:
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = (
session.query(DocumentSegmentSummary)
.filter(
summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
)
.all()
)
).all()
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
@@ -851,12 +851,12 @@ class RetrievalService:
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
if upload_file:
attachment_binding = (
session.query(SegmentAttachmentBinding)
attachment_binding = session.scalar(
select(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first()
.limit(1)
)
if attachment_binding:
attachment_info: AttachmentInfoDict = {
@@ -875,14 +875,12 @@ class RetrievalService:
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
).all()
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings:

View File

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, TypedDict
from pydantic import BaseModel, model_validator
@@ -13,6 +13,13 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbClientParamsDict(TypedDict):
access_key_id: str
access_key_secret: str
region_id: str
read_timeout: int
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
@@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self):
return {
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
result: AnalyticdbClientParamsDict = {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
return result
class AnalyticdbVectorOpenAPI:

View File

@@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -85,8 +85,12 @@ class BaiduVector(BaseVector):
def get_type(self) -> str:
return VectorType.BAIDU
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))

View File

@@ -1,12 +1,12 @@
import json
from typing import Any
from typing import Any, TypedDict
import chromadb
from chromadb import QueryResult, Settings
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
class ChromaParamsDict(TypedDict):
host: str
port: int
ssl: bool
tenant: str
database: str
settings: Settings
class ChromaConfig(BaseModel):
host: str
port: int
@@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
auth_provider: str | None = None
auth_credentials: str | None = None
def to_chroma_params(self):
def to_chroma_params(self) -> ChromaParamsDict:
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials,
)
return {
result: ChromaParamsDict = {
"host": self.host,
"port": self.port,
"ssl": False,
@@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
"database": self.database,
"settings": settings,
}
return result
class ChromaVector(BaseVector):
@@ -145,7 +154,10 @@ class ChromaVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
index_struct_dict: VectorIndexStructDict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, TypedDict
from packaging import version
from pydantic import BaseModel, model_validator
@@ -20,6 +20,15 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MilvusParamsDict(TypedDict):
uri: str
token: str | None
user: str | None
password: str | None
db_name: str
analyzer_params: str | None
class MilvusConfig(BaseModel):
"""
Configuration class for Milvus connection.
@@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
def to_milvus_params(self) -> MilvusParamsDict:
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
result: MilvusParamsDict = {
"uri": self.uri,
"token": self.token,
"user": self.user,
@@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
"db_name": self.database,
"analyzer_params": self.analyzer_params,
}
return result
class MilvusVector(BaseVector):
@@ -352,6 +362,7 @@ class MilvusVector(BaseVector):
# Create Index params for the collection
index_params_obj = IndexParams()
assert index_params is not None
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection

View File

@@ -22,7 +22,7 @@ from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -94,8 +94,12 @@ class QdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.QDRANT
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:

View File

@@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any
from typing import Any, TypedDict
from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore
@@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -23,6 +23,13 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class TencentParamsDict(TypedDict):
url: str
username: str | None
key: str | None
timeout: float
class TencentConfig(BaseModel):
url: str
api_key: str | None = None
@@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
def to_tencent_params(self) -> TencentParamsDict:
result: TencentParamsDict = {
"url": self.url,
"username": self.username,
"key": self.api_key,
"timeout": self.timeout,
}
return result
bm25 = BM25Encoder.default("zh")
@@ -83,8 +96,12 @@ class TencentVector(BaseVector):
def get_type(self) -> str:
return VectorType.TENCENT
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def _has_collection(self) -> bool:
return bool(

View File

@@ -25,7 +25,7 @@ from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:

View File

@@ -1,11 +1,20 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, TypedDict
from core.rag.models.document import Document
class VectorStoreDict(TypedDict):
class_prefix: str
class VectorIndexStructDict(TypedDict):
type: str
vector_store: VectorStoreDict
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name

View File

@@ -9,7 +9,7 @@ from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
@@ -30,8 +30,11 @@ class AbstractVectorFactory(ABC):
raise NotImplementedError
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
return index_struct_dict

View File

@@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@@ -184,9 +184,13 @@ class WeaviateVector(BaseVector):
dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict:
def to_index_struct(self) -> VectorIndexStructDict:
"""Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""

View File

@@ -0,0 +1,28 @@
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
__all__ = [
"Condition",
"DatasourceCompletedEvent",
"DatasourceErrorEvent",
"DatasourceProcessingEvent",
"DocumentContext",
"EconomySetting",
"EmbeddingSetting",
"IndexMethod",
"KeywordSetting",
"MetadataFilteringCondition",
"ParentMode",
"PreProcessingRule",
"RetrievalSourceMetadata",
"Rule",
"Segmentation",
"SupportedComparisonOperator",
"VectorSetting",
"WeightedScoreConfig",
]

View File

@@ -0,0 +1,30 @@
from typing import Literal
from pydantic import BaseModel
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting

View File

@@ -38,9 +38,9 @@ class Condition(BaseModel):
value: str | Sequence[str] | None | int | float = None
class MetadataCondition(BaseModel):
class MetadataFilteringCondition(BaseModel):
"""
Metadata Condition.
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"

View File

@@ -0,0 +1,27 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None

View File

@@ -0,0 +1,28 @@
from pydantic import BaseModel
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting

View File

@@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory
@@ -61,7 +61,7 @@ class IndexProcessor:
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
) -> IndexingResultDict:
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if not document:
@@ -129,7 +129,7 @@ class IndexProcessor:
}
)
return {
result: IndexingResultDict = {
"dataset_id": dataset_id,
"dataset_name": dataset_name_value,
"batch": batch,
@@ -138,6 +138,7 @@ class IndexProcessor:
"created_at": created_at_value.timestamp(),
"display_status": "completed",
}
return result
def get_preview_output(
self,

View File

@@ -32,6 +32,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
@@ -49,7 +50,6 @@ from models.account import Account
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()

View File

@@ -17,6 +17,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import ParentMode, Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
@@ -30,7 +31,6 @@ from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)

View File

@@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -30,7 +31,6 @@ from libs import helper
from models.account import Account
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)

View File

@@ -1,16 +1,6 @@
from pydantic import BaseModel
class VectorSetting(BaseModel):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
keyword_weight: float
from core.rag.entities import KeywordSetting, VectorSetting
class Weights(BaseModel):

View File

@@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
@@ -604,7 +602,7 @@ class DatasetRetrieval:
planning_strategy: PlanningStrategy,
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
metadata_condition: MetadataFilteringCondition | None = None,
):
tools = []
for dataset in available_datasets:
@@ -743,7 +741,7 @@ class DatasetRetrieval:
reranking_enable: bool = True,
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
metadata_condition: MetadataFilteringCondition | None = None,
attachment_ids: list[str] | None = None,
):
if not available_datasets:
@@ -1063,7 +1061,7 @@ class DatasetRetrieval:
top_k: int,
all_documents: list[Document],
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
metadata_condition: MetadataFilteringCondition | None = None,
attachment_ids: list[str] | None = None,
):
with flask_app.app_context():
@@ -1339,7 +1337,7 @@ class DatasetRetrieval:
metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
@@ -1371,7 +1369,7 @@ class DatasetRetrieval:
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
metadata_condition = MetadataFilteringCondition(
logical_operator=metadata_filtering_conditions.logical_operator
if metadata_filtering_conditions
else "or", # type: ignore
@@ -1400,7 +1398,7 @@ class DatasetRetrieval:
expected_value,
filters,
)
metadata_condition = MetadataCondition(
metadata_condition = MetadataFilteringCondition(
logical_operator=metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
@@ -1723,7 +1721,7 @@ class DatasetRetrieval:
self,
flask_app: Flask,
available_datasets: list[Dataset],
metadata_condition: MetadataCondition | None,
metadata_condition: MetadataFilteringCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],
tenant_id: str,

View File

@@ -1,6 +1,15 @@
from typing import TypedDict
from pydantic import BaseModel, Field, model_validator
class I18nObjectDict(TypedDict):
zh_Hans: str | None
en_US: str
pt_BR: str | None
ja_JP: str | None
class I18nObject(BaseModel):
"""
Model class for i18n object.
@@ -18,5 +27,11 @@ class I18nObject(BaseModel):
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self):
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
def to_dict(self) -> I18nObjectDict:
result: I18nObjectDict = {
"zh_Hans": self.zh_Hans,
"en_US": self.en_US,
"pt_BR": self.pt_BR,
"ja_JP": self.ja_JP,
}
return result

View File

@@ -6,9 +6,20 @@ from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
ValidationInfo,
field_serializer,
field_validator,
model_validator,
)
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import (
MCPServerParameterType,
PluginParameter,
@@ -18,11 +29,19 @@ from core.plugin.entities.parameters import (
cast_parameter_value,
init_frontend_parameter,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
class EmojiIconDict(TypedDict):
background: str
content: str
emoji_icon_adapter: TypeAdapter[EmojiIconDict] = TypeAdapter(EmojiIconDict)
class ToolLabelEnum(StrEnum):
SEARCH = "search"
IMAGE = "image"
@@ -410,15 +429,6 @@ class ToolEntity(BaseModel):
return value or {}
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
)
credentials_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: str | None = None

View File

@@ -5,16 +5,19 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
import sqlalchemy as sa
from graphon.runtime import VariablePool
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from yarl import URL
import contexts
from configs import dify_config
from core.entities import PluginCredentialType
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -27,7 +30,6 @@ from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
@@ -49,9 +51,11 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
EmojiIconDict,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
emoji_icon_adapter,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
@@ -72,9 +76,7 @@ class ApiProviderControllerItem(TypedDict):
controller: ApiToolProviderController
class EmojiIconDict(TypedDict):
background: str
content: str
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class WorkflowToolRuntimeSpec(Protocol):
@@ -885,7 +887,7 @@ class ToolManager:
raise ValueError(f"you have not added provider {provider_name}")
try:
credentials = json.loads(provider_obj.credentials_str) or {}
credentials = _credentials_adapter.validate_json(provider_obj.credentials_str) or {}
except Exception:
credentials = {}
@@ -910,7 +912,7 @@ class ToolManager:
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)
icon = emoji_icon_adapter.validate_json(provider_obj.icon)
except Exception:
icon = {"background": "#252525", "content": "\ud83d\ude01"}
@@ -973,7 +975,7 @@ class ToolManager:
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
icon = json.loads(workflow_provider.icon)
icon = emoji_icon_adapter.validate_json(workflow_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@@ -990,7 +992,7 @@ class ToolManager:
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
icon = json.loads(api_provider.icon)
icon = emoji_icon_adapter.validate_json(api_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@@ -8,7 +8,7 @@ from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner

View File

@@ -6,8 +6,7 @@ from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval

Some files were not shown because too many files have changed in this diff Show More