Compare commits

...

94 Commits

Author SHA1 Message Date
Joel
333aae6795 fix: others headless change 2025-01-23 15:11:05 +08:00
Joel
e797e97b18 chore: some menu 2025-01-23 14:01:14 +08:00
Joel
4ad152f525 chore: change some 2025-01-23 10:30:49 +08:00
Joel
58403b8238 fix: build error 2025-01-22 16:14:04 +08:00
Joel
a465f092eb fix: headers to v15 2025-01-22 15:51:44 +08:00
Joel
f981eb4640 fix: params to v15 2025-01-22 15:41:32 +08:00
Joel
25aaf53375 chore: temp 2025-01-21 18:30:35 +08:00
Joel
5ec7920c42 fix: fe check 2025-01-14 14:38:55 +08:00
Joel
4c49e48465 chore: template and config 2025-01-09 16:23:25 +08:00
Joel
ed93954add feat: support config max chunk length by env in frontend 2025-01-09 14:42:01 +08:00
NFish
b7a4e3903e fix: add last_refresh_time to track the validity of is_other_tab_refreshing (#12517) 2025-01-09 10:40:45 +08:00
Hiroshi Fujita
b4c1c2f731 fix: Reverse sync docker-compose-template.yaml (#12509) 2025-01-09 10:21:22 +08:00
kurokobo
1b940e7daa feat: add ci job to test template for docker compose (#12514) 2025-01-09 00:04:58 +08:00
非法操作
f4ee50a7ad chore: improve app doc (#12490) 2025-01-08 18:37:12 +08:00
Jyong
bee32d960a fix #12453 #12482 (#12495) 2025-01-08 18:26:05 +08:00
YoungLH
040a3b782c FEAT: support milvus to full text search (#11430)
Signed-off-by: YoungLH <974840768@qq.com>
2025-01-08 17:39:53 +08:00
非法操作
d649037c3e feat: support single run doc extractor node (#11318) 2025-01-08 15:20:15 +08:00
-LAN-
0a49d3dd52 fix: tiktoken cannot be loaded without internet (#12478)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-08 14:49:44 +08:00
Yingchun Lai
53bb37b749 fix: fix the incorrect plaintext file key when saving (#10429) 2025-01-08 12:52:45 +08:00
Hiroshi Fujita
d2586278d6 Feat elasticsearch japanese (#12194) 2025-01-08 12:35:41 +08:00
Wu Tianwei
6635c393e9 fix: adjust opacity for model selector based on readonly state (#12472) 2025-01-08 12:11:45 +08:00
crazywoola
6222179a57 Revert "fix:deepseek tool call not working correctly" (#12463) 2025-01-08 10:50:34 +08:00
Jyong
05bda6f38d add tidb on qdrant redis lock (#12462) 2025-01-08 08:55:44 +08:00
Hiroshi Fujita
4295cefeb1 fix: allow fallback to remote_url when url is not provided (#12455) 2025-01-07 22:33:25 +08:00
非法操作
67228c9b26 fix: url with variable not work (#12452) 2025-01-07 21:55:51 +08:00
Jyong
fd2bfff023 remove knowledge admin role (#12450) 2025-01-07 21:30:23 +08:00
Infinitnet
4e6c86341d Add 'document' feature to Sonnet 3.5 through OpenRouter (#12444) 2025-01-07 19:51:38 +08:00
ybalbert001
2a14c67edc Fix #12448 - update bedrock retrieve tool, support hybrid search type and re… (#12446)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2025-01-07 19:51:23 +08:00
-LAN-
c236f05f4b chore: bump version to 0.15.0 (#12297)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-07 18:05:14 +08:00
-LAN-
0eeacdc80c refactor: enhance API token validation with session locking and last used timestamp update (#12426)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-07 18:04:41 +08:00
hisir
41f39bf3fc Fix newline characters in tables during document parsing (#12112)
Co-authored-by: hisir <admin@qq.com>
2025-01-07 17:26:24 +08:00
呆萌闷油瓶
9677144015 fix:deepseek tool call not working correctly (#12437) 2025-01-07 17:25:38 +08:00
SiliconFlow, Inc
15797c556f add fish-speech-1.5 from siliconflow (#12425) 2025-01-07 15:27:34 +08:00
-LAN-
acacf35a2a chore(docker/.env.example): Add TOP_K_MAX_VALUE to the .env.example… (#12422)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-07 14:51:16 +08:00
-LAN-
d3f5b1cbb6 refactor: use tiktoken for token calculation (#12416)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-07 13:32:30 +08:00
why
196ed8101b fix: [PromptEditorHeightResizeWrap] Bug #12410 (#12406) 2025-01-07 12:21:54 +08:00
SiliconFlow, Inc
dc650c5368 Fixes #12414: Add cheaper model and long context model for Qwen2.5-72B-Instruct from siliconflow (#12415) 2025-01-07 11:28:24 +08:00
Alex Chen
2bb521b135 Support TTS and Speech2Text for Model Provider GPUStack (#12381) 2025-01-07 09:42:11 +08:00
SiliconFlow, Inc
409cc7d9b0 mark deprecated models in siliconflow #12399 (#12405)
Co-authored-by: crazywoola <427733928@qq.com>
2025-01-07 09:08:58 +08:00
yihong
fe26be2312 fix: http method can be upper case and lower case close #11877 (#12401)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-01-06 20:35:53 +08:00
Wu Tianwei
34519de3b7 fix: enhance chunk list management with new invalidation keys and imp… (#12396)
ok
2025-01-06 18:22:16 +08:00
Warren Chen
147d578922 [Fix] revert sagemaker llm to support model hub (#12378) 2025-01-06 18:01:45 +08:00
Jyong
9c317b64c3 sandbox doesn't provide auto disable log (#12388) 2025-01-06 15:57:13 +08:00
Joel
3b8f6233b0 feat: support config top max value by env (#12375) 2025-01-06 10:38:14 +08:00
Hash Brown
455b0cd696 chore: chat app textarea auto focus (#12366) 2025-01-05 21:25:00 +08:00
eux
1fa66405c5 feat: support configuration of refresh token expiration by environment variable (#12335) 2025-01-04 11:56:44 +08:00
Wood
b680a85b57 fix: resolve issue with the opening statement generated by the AutomaticRes component failing to sync between states. (#12349) 2025-01-04 11:56:11 +08:00
Wood
682ebc5f64 Fix the issue where TextGeneration component does not correctly clear input data. (#12351) 2025-01-04 11:55:55 +08:00
Wood
b8ba39dfae Bugfix/style and i18n fixes (#12350) 2025-01-04 11:52:13 +08:00
Wood
6c9e6a3a5a fix: fix issue with chat-input-area clearing during Responding state. (#12352) 2025-01-04 11:51:35 +08:00
huangzhuo1949
70698024f5 fix: empty delete bug (#12339)
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
2025-01-03 20:46:39 +08:00
方程
6df17a334c fix: Update the API call address for the text_embedding model (#12342)
Co-authored-by: 方程 <fangcheng@oschina.cn>
2025-01-03 19:19:17 +08:00
zhu-an
a5fb59b17f fix: Encode Chinese characters with Unicode before querying to match the Unicode encoded Chinese characters in the db (#12337)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2025-01-03 19:12:48 +08:00
-LAN-
7ed6485f86 refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes (#12326)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-03 18:41:44 +08:00
Timmy-web
478150e850 chore: fix typo in zh-Hant localization (#12329) 2025-01-03 17:38:30 +08:00
jifei
3c2e30f348 fix: #12143 support streaming mode content start with "data:" (#12171) 2025-01-03 16:33:37 +08:00
Jyong
b873e6349c add child chunk preview number limit (#12309) 2025-01-03 16:14:27 +08:00
Shun Miyazawa
2b1a32fd9c feat: Add filter to show only apps created by the user (#11968) 2025-01-03 15:38:36 +08:00
Nam Vu
a2105634a4 chore: change app card layout follow by #10847 (#12317) 2025-01-03 14:16:17 +08:00
yagiyuki
7c71bd7be7 doc: Added explanation of chunk_overlap to knowledge API (#12247)
Co-authored-by: crazywoola <427733928@qq.com>
2025-01-03 10:02:17 +08:00
丹枫染秋色
7c1961e618 feat: Add response format support to GLM-4 (#12252) 2025-01-03 09:38:50 +08:00
xander-art
baeddd4d15 feat:Add support for stop parameter in hunyuan model #12313 (#12315)
Co-authored-by: xander-art <xander-art@gmail.com>
2025-01-03 09:15:04 +08:00
-LAN-
6f5a8a33d9 refactor: replace gevent threadpool with ProcessPoolExecutor in GPT2Tokenizer (#12316)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-03 09:13:18 +08:00
-LAN-
52b2559a14 fix(app.py): if condition (#12314)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-03 01:36:23 +08:00
yihong
3d150c30a7 fix: utcfromtimestamp is Deprecated change to new api (#12120)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-01-02 18:52:36 +08:00
Wu Tianwei
e58e573f3e fix: add full doc mode preview length limit (#12310)
ok
2025-01-02 18:36:49 +08:00
-LAN-
375aa38f5d fix: improve content decoding in AppDslService (#12304)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-02 16:54:46 +08:00
-LAN-
0e6317678f Fix code scanning alert no. 111: Incomplete URL substring sanitization (#12305)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-01-02 16:52:43 +08:00
dependabot[bot]
e7dffcd0f6 chore(deps): bump aiohttp from 3.10.5 to 3.10.11 in /api (#12303)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-02 16:42:47 +08:00
Hanqing Zhao
065304d175 Modify jp translation (#12292) 2025-01-02 16:28:27 +08:00
-LAN-
15f43dd326 chore(deps): update yarl version from 1.9.4 to 1.18.3 (#12302)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-02 16:26:47 +08:00
Wu Tianwei
09d759d196 fix: Fix parent child retrieval issues (#12206)
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
2025-01-02 16:07:21 +08:00
dependabot[bot]
68757950ce chore(deps): bump jinja2 from 3.1.4 to 3.1.5 in /api (#12300)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-02 16:04:10 +08:00
-LAN-
3c45bdf18a fix: disable gevent in debug mode for better compatibility with JetBrains Python debugger (#12299)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-02 15:59:30 +08:00
Yingchun Lai
c135967e59 refactor: simplify some exception catch code (#12246) 2025-01-02 14:25:12 +08:00
Sean Sun
f71af7c2a8 fix: DocumentAddByFileApi miss data_source_type field but there is a mandatory value check (#12273) 2025-01-02 14:24:15 +08:00
github-actions[bot]
5b01eb9437 chore: translate i18n files (#12208)
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
2025-01-02 10:04:43 +08:00
呆萌闷油瓶
2e716f80d2 fix:The chart of average interaction counts per conversation show not… (#12199) 2025-01-02 10:02:56 +08:00
Giovanny Gutiérrez
d7c0bc8c23 feat: Add response format support for openai compat models (#12240)
Co-authored-by: Gio Gutierrez <giovannygutierrez@gmail.com>
2025-01-02 09:59:34 +08:00
yihong
f30bf08580 fix: close #12215 for yi special case (#12222)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-01-02 09:58:34 +08:00
-LAN-
a640803fc9 fix(models): use bigint on workflow_runs.total_tokens (#12279)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-01-02 09:49:34 +08:00
Warren Chen
9954ddb780 [Fix] modify sagemaker llm (#12274) 2025-01-02 09:49:11 +08:00
非法操作
b218df6920 fix: draft run single node can't get env variable (#12266) 2025-01-01 13:31:44 +08:00
-LAN-
5b6950e545 fix: improve error handling in NotionOAuth for block parent page ID r… (#12268)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 17:03:26 +08:00
-LAN-
c7911c7130 fix: improve JSON parsing error handling in Executor class (#12265)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 17:03:07 +08:00
-LAN-
62f792ea14 fix: ensure workflow_run_id is always set and improve handling in Wor… (#12264)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 17:02:44 +08:00
-LAN-
6a85960605 feat: implement asynchronous token counting in GPT2Tokenizer (#12239)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 17:02:08 +08:00
-LAN-
63a0b8ba79 feat: integrate psycogreen for gevent compatibility in PostgreSQL (#12253)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 14:45:59 +08:00
-LAN-
634b382a3d fix: enhance ToolEngineInvokeError to include meta information (#12238)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 14:01:24 +08:00
Han K
fbf5deda21 fix(env): docker compose variable interpolation issue for COMPOSE_PRO… (#12093)
Co-authored-by: Han Kyaw <hankyaw@Hans-MBP.lan>
2024-12-31 13:46:51 +08:00
-LAN-
d4b848272e fix: apply gevent threading patch early and ensure unique workflow node execution IDs (#12196)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-31 11:42:51 +08:00
Warren Chen
fc29f2003e translate comments (#12234) 2024-12-31 00:36:03 +08:00
Bowen Liang
ab469aa07d feat: support opening new tab in markdown button (#12213) 2024-12-30 22:27:25 +08:00
Warren Chen
562450751f [Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#12227) 2024-12-30 22:26:04 +08:00
308 changed files with 5137 additions and 3100 deletions

View File

@@ -82,6 +82,33 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter

View File

@@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

View File

@@ -1,12 +1,8 @@
from libs import version_utils
# preparation before creating app
version_utils.check_supported_python_version()
import os
import sys
def is_db_command():
import sys
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False
@@ -18,10 +14,25 @@ if is_db_command():
app = create_migrations_app()
else:
from app_factory import create_app
from libs import threadings_utils
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore
threadings_utils.apply_gevent_threading_patch()
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]

View File

@@ -488,6 +488,11 @@ class AuthConfig(BaseSettings):
default=60,
)
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
@@ -667,6 +672,11 @@ class IndexingConfig(BaseSettings):
default=4000,
)
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(

View File

@@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.2",
default="0.15.0",
)
COMMIT_SHA: str = Field(

View File

@@ -57,12 +57,13 @@ class AppListApi(Resource):
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

View File

@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -273,8 +273,7 @@ FROM
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)

View File

@@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
@@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM

View File

@@ -257,7 +257,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)

View File

@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper

View File

@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user

View File

@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@@ -74,7 +73,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -133,7 +132,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
@@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
@@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator
def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()
if not api_token:
raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token

View File

@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value

View File

@@ -14,7 +14,11 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser

View File

@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db
@@ -336,7 +336,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -67,24 +67,17 @@ from models.account import Account
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
class AdvancedChatAppGenerateTaskPipeline:
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -96,7 +89,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool,
dialogue_count: int,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@@ -113,32 +106,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
self._conversation_name_generate_thread = None
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
self._workflow_run_id: str = ""
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -146,13 +142,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -269,24 +265,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
with Session(db.engine, expire_on_commit=False) as session:
err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit()
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
@@ -297,7 +295,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response(
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@@ -310,12 +308,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -329,13 +329,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_resp = self._workflow_node_start_to_stream_response(
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -348,12 +350,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -364,10 +370,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -381,13 +389,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
@@ -395,13 +407,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
@@ -409,9 +425,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -423,9 +441,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -437,9 +457,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -454,8 +476,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -466,21 +488,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -491,21 +515,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -517,20 +543,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()
yield workflow_finish_resp
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -541,7 +569,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -555,18 +583,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager._handle_retriever_resources(event)
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._message_cycle_manager._handle_annotation_reply(event)
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -587,23 +615,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(
yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)
# Save message
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
@@ -621,7 +653,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
@@ -685,20 +717,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
return False

View File

@@ -18,7 +18,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -245,7 +245,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -221,6 +221,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -270,7 +271,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from typing import Optional, Union
from sqlalchemy.orm import Session
@@ -58,7 +58,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@@ -66,16 +65,11 @@ from models.workflow import (
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
@@ -84,7 +78,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser],
stream: bool,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@@ -101,19 +95,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -122,7 +118,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -237,29 +233,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
graph_runtime_state = None
for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@@ -271,12 +267,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -290,12 +288,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -306,9 +306,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
node_success_response = self._workflow_node_finish_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -319,12 +321,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -339,13 +341,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
@@ -354,13 +360,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
@@ -369,9 +379,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -384,9 +396,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -399,9 +413,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -416,8 +432,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -431,7 +447,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -445,8 +461,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -461,7 +477,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@@ -473,8 +489,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -492,7 +508,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

View File

@@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
# app config
app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None
workflow_run_id: str
class SingleIterationRunEntity(BaseModel):
"""

View File

@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()

View File

@@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""

View File

@@ -34,7 +34,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -58,13 +57,20 @@ from models.workflow import (
WorkflowRunStatus,
)
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start(
self,
@@ -102,7 +108,8 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id
@@ -239,7 +246,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where(
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@@ -247,16 +254,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
running_workflow_node_executions = session.scalars(stmt).all()
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
if trace_manager:
trace_manager.add_trace_task(
@@ -274,7 +283,7 @@ class WorkflowCycleManage:
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = event.node_execution_id
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
@@ -298,6 +307,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(
@@ -325,6 +336,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
@@ -364,6 +376,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
@@ -391,7 +404,7 @@ class WorkflowCycleManage:
execution_metadata = json.dumps(merged_metadata)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = event.node_execution_id
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
@@ -415,6 +428,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
#################################################
@@ -811,22 +826,20 @@ class WorkflowCycleManage:
return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution

View File

@@ -1,9 +1,6 @@
from os.path import abspath, dirname, join
from threading import Lock
from typing import Any
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
_tokenizer: Any = None
_lock = Lock()
@@ -15,11 +12,16 @@ class GPT2Tokenizer:
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text, verbose=False)
tokens = _tokenizer.encode(text)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
@@ -27,8 +29,19 @@ class GPT2Tokenizer:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer

View File

@@ -24,8 +24,5 @@ class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "bge-m3"
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/"
def _add_custom_parameters(credentials: dict, model: str) -> None:
credentials["endpoint_url"] = "https://ai.gitee.com/v1"

View File

@@ -9,6 +9,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@@ -118,3 +120,19 @@ model_credential_schema:
label:
en_US: Not Support
zh_Hans: 不支持
- variable: voices
show_on:
- variable: __model_type
value: tts
label:
en_US: Available Voices (comma-separated)
zh_Hans: 可用声音(用英文逗号分隔)
type: text-input
required: false
default: "Chinese Female"
placeholder:
en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
help:
en_US: "List voice names separated by commas. First voice will be used as default."
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"

View File

@@ -1,7 +1,5 @@
from collections.abc import Generator
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
@@ -24,9 +22,10 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(
model,
credentials,
compatible_credentials,
prompt_messages,
model_parameters,
tools,
@@ -36,10 +35,15 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)
def _get_compatible_credentials(self, credentials: dict) -> dict:
credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
credentials["endpoint_url"] = f"{base_url}/v1-openai"
return credentials
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"

View File

@@ -0,0 +1,43 @@
from typing import IO, Optional
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel):
"""
Model class for GPUStack Speech to text model.
"""
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(model, compatible_credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
"""
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)
def _get_compatible_credentials(self, credentials: dict) -> dict:
"""
Get compatible credentials
:param credentials: model credentials
:return: compatible credentials
"""
compatible_credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
return compatible_credentials

View File

@@ -1,7 +1,5 @@
from typing import Optional
from yarl import URL
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
@@ -24,12 +22,15 @@ class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(model, compatible_credentials, texts, user, input_type)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
def _get_compatible_credentials(self, credentials: dict) -> dict:
credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
credentials["endpoint_url"] = f"{base_url}/v1-openai"
return credentials

View File

@@ -0,0 +1,57 @@
from typing import Any, Optional
from core.model_runtime.model_providers.openai_api_compatible.tts.tts import OAICompatText2SpeechModel
class GPUStackText2SpeechModel(OAICompatText2SpeechModel):
"""
Model class for GPUStack Text to Speech model.
"""
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> Any:
"""
Invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:param user: unique user id
:return: text translated to audio file
"""
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(
model=model,
tenant_id=tenant_id,
credentials=compatible_credentials,
content_text=content_text,
voice=voice,
user=user,
)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:param user: unique user id
"""
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)
def _get_compatible_credentials(self, credentials: dict) -> dict:
"""
Get compatible credentials
:param credentials: model credentials
:return: compatible credentials
"""
compatible_credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
return compatible_credentials

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -6,6 +6,7 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
@@ -19,6 +20,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
@@ -18,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -19,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -19,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -19,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -19,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.1'

View File

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
@@ -18,6 +19,18 @@ parameter_rules:
default: 1024
min: 1
max: 32768
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "0.05"
output: "0.1"

View File

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
@@ -18,6 +19,18 @@ parameter_rules:
default: 1024
min: 1
max: 32768
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "0.05"
output: "0.1"

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.20'
output: '0.20'

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.7'
output: '0.8'

View File

@@ -18,6 +18,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.59'
output: '0.79'

View File

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 8192
@@ -18,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.08'

View File

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 8192
@@ -18,6 +19,18 @@ parameter_rules:
default: 512
min: 1
max: 8192
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.08'

View File

@@ -54,6 +54,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
"Model": model,
"Messages": messages_dict,
"Stream": stream,
"Stop": stop,
**custom_parameters,
}
# add Tools and ToolChoice

View File

@@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().removeprefix("data: ")
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)

View File

@@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '2.50'
output: '10.00'

View File

@@ -739,6 +739,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
delta = chunk.choices[0]
has_finish_reason = delta.finish_reason is not None
# to fix issue #12215 yi model has special case for ligthing
# FIXME drop the case when yi model is updated
if model.startswith("yi-"):
if isinstance(delta.finish_reason, str):
# doc: https://platform.lingyiwanwu.com/docs/api-reference
has_finish_reason = delta.finish_reason.startswith(("length", "stop", "content_filter"))
if (
not has_finish_reason

View File

@@ -332,6 +332,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
if not endpoint_url.endswith("/"):
endpoint_url += "/"
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
elif "json_schema" in model_parameters:
del model_parameters["json_schema"]
data = {"model": model, "stream": stream, **model_parameters}
completion_type = LLMMode.value_of(credentials["mode"])
@@ -462,7 +479,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().removeprefix("data: ")
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
continue

View File

@@ -7,6 +7,7 @@ features:
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 200000

View File

@@ -1,4 +1,3 @@
- Tencent/Hunyuan-A52B-Instruct
- Qwen/QwQ-32B-Preview
- Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2.5-32B-Instruct
@@ -6,11 +5,11 @@
- Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen2.5-Coder-32B-Instruct
- Qwen/Qwen2.5-Coder-7B-Instruct
- Qwen/Qwen2.5-Math-72B-Instruct
- Qwen/Qwen2-VL-72B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- Qwen/Qwen2.5-72B-Instruct-128K
- Vendor-A/Qwen/Qwen2.5-72B-Instruct
- Pro/Qwen/Qwen2-VL-7B-Instruct
- OpenGVLab/InternVL2-Llama3-76B
- OpenGVLab/InternVL2-26B
- Pro/OpenGVLab/InternVL2-8B
- deepseek-ai/DeepSeek-V2.5

View File

@@ -82,3 +82,4 @@ pricing:
output: '21'
unit: '0.000001'
currency: RMB
deprecated: true

View File

@@ -82,3 +82,4 @@ pricing:
output: '21'
unit: '0.000001'
currency: RMB
deprecated: true

View File

@@ -0,0 +1,54 @@
model: Qwen/QVQ-72B-Preview
label:
en_US: Qwen/QVQ-72B-Preview
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 8192
min: 1
max: 16384
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '9.90'
output: '9.90'
unit: '0.000001'
currency: RMB

View File

@@ -15,9 +15,9 @@ parameter_rules:
- name: max_tokens
use_template: max_tokens
type: int
default: 512
default: 4096
min: 1
max: 4096
max: 8192
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.

View File

@@ -78,7 +78,7 @@ parameter_rules:
- text
- json_object
pricing:
input: '21'
output: '21'
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB

View File

@@ -78,7 +78,7 @@ parameter_rules:
- text
- json_object
pricing:
input: '21'
output: '21'
input: '0.35'
output: '0.35'
unit: '0.000001'
currency: RMB

View File

@@ -0,0 +1,51 @@
model: Qwen/Qwen2.5-72B-Instruct-128K
label:
en_US: Qwen/Qwen2.5-72B-Instruct-128K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB

View File

@@ -0,0 +1,51 @@
model: Vendor-A/Qwen/Qwen2.5-72B-Instruct
label:
en_US: Vendor-A/Qwen/Qwen2.5-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '1.00'
output: '1.00'
unit: '0.000001'
currency: RMB

View File

@@ -15,7 +15,7 @@ parameter_rules:
type: int
default: 512
min: 1
max: 8192
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.

View File

@@ -82,3 +82,4 @@ pricing:
output: '4.13'
unit: '0.000001'
currency: RMB
deprecated: true

View File

@@ -0,0 +1,37 @@
model: fishaudio/fish-speech-1.5
model_type: tts
model_properties:
default_voice: 'fishaudio/fish-speech-1.5:alex'
voices:
- mode: "fishaudio/fish-speech-1.5:alex"
name: "Alex男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:benjamin"
name: "Benjamin男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:charles"
name: "Charles男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:david"
name: "David男声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:anna"
name: "Anna女声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:bella"
name: "Bella女声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:claire"
name: "Claire女声"
language: [ "zh-Hans", "en-US" ]
- mode: "fishaudio/fish-speech-1.5:diana"
name: "Diana女声"
language: [ "zh-Hans", "en-US" ]
audio_type: 'mp3'
max_workers: 5
# stream: false
pricing:
input: '0.015'
output: '0'
unit: '0.001'
currency: RMB

View File

@@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().removeprefix("data: ")
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.1'
output: '0.1'

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.001'
output: '0.001'

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.01'
output: '0.01'

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0'
output: '0'

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0'
output: '0'

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.1'
output: '0.1'

View File

@@ -49,6 +49,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.001'
output: '0.001'

View File

@@ -47,6 +47,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.05'

View File

@@ -45,6 +45,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.05'
output: '0.05'

View File

@@ -45,6 +45,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.00'
output: '0.00'

View File

@@ -46,6 +46,18 @@ parameter_rules:
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.01'
output: '0.01'

View File

@@ -1,3 +1,4 @@
import json
from collections.abc import Generator
from typing import Optional, Union
@@ -188,6 +189,23 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else:
model_parameters["tools"] = [web_search_params]
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
elif "json_schema" in model_parameters:
del model_parameters["json_schema"]
if model.startswith("glm-4v"):
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else:

View File

@@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

View File

@@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
self._client.delete_collection(self._collection_name)
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)

View File

@@ -0,0 +1,104 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

View File

@@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)

View File

@@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

View File

@@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
@@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
@@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
@@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
@@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
@@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False
@@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
"""
Search for documents by vector similarity.
"""
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
@@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
@@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=collection_name,
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
@@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)

View File

@@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
return results.row_count > 0
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)

View File

@@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:

View File

@@ -167,6 +167,8 @@ class OracleVector(BaseVector):
return docs
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))

View File

@@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs
def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

View File

@@ -140,6 +140,8 @@ class TencentVector(BaseVector):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:

View File

@@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
@@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

Some files were not shown because too many files have changed in this diff Show More