Compare commits

..

58 Commits

Author SHA1 Message Date
Novice Lee
fde3fe0ab6 fix: reformat the http node file 2024-12-20 13:15:44 +08:00
Novice Lee
07528f82b9 Merge branch 'main' into feat/node-execution-retry 2024-12-20 11:21:53 +08:00
Dr.MerdanBay
bb2f46d7cc fix: add safe dictionary access for bedrock credentials (#11860) 2024-12-20 12:13:39 +09:00
yihong
463fbe2680 fix: better gard nan value from numpy for issue #11827 (#11864)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-20 09:28:32 +08:00
傻笑zz
95a7e50137 Fix comfyui tool https (#11859) 2024-12-20 09:27:21 +08:00
非法操作
9d93ad1f16 feat: add gemini-2.0-flash-thinking-exp-1219 (#11863) 2024-12-20 09:26:31 +08:00
stardust
44104797d6 fix: Enhance file type detection in HTTP Request node (#11797)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: 谭成 <tancheng.sh@chinatelecom.cn>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-20 02:21:41 +08:00
傻笑zz
1548501050 fix: comfyui tool supports https (#11823) 2024-12-19 23:05:27 +08:00
crazywoola
de3911e930 Fix/10584 wrong message when no custom tool available in custom tool list (#11851) 2024-12-19 21:19:08 +08:00
yihong
5a8a901560 fix: float values are not json for nan value close #11827 (#11840)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-19 20:50:20 +08:00
yihong
12d45e9114 fix: silicon change its model fix #11844 (#11847)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-19 20:50:09 +08:00
barabicu
d057067543 fix: remove ruff ignore SIM300 (#11810) 2024-12-19 18:30:51 +08:00
sino
560d375e0f feat(ark): add doubao-pro-256k and doubao-embedding-large (#11831) 2024-12-19 17:49:31 +08:00
Novice Lee
127291a90f feat: add single step retry 2024-12-19 17:03:05 +08:00
Novice Lee
9e0c28791d fix: resolve code merge issues 2024-12-19 14:46:19 +08:00
Agung Besti
3388d6636c add-model-azure-gpt-4o-2024-11-20 (#11803)
Co-authored-by: agungbesti <agung.besti@insignia.co.id>
2024-12-19 12:36:11 +08:00
Charlie.Wei
2624a6dcd0 Fix explore app icon (#11808)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-12-18 21:24:21 +08:00
yihong
b5c2785e10 ci: fix config ci and it works (#11807)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-18 20:17:10 +08:00
yihong
493834d45d ci: add config ci more disscuss check #11706 (#11752)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-18 17:36:36 +08:00
-LAN-
926546b153 chore: bump version to 0.14.1 (#11784)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-18 16:35:54 +08:00
Novice Lee
b411087bb7 Merge branch 'main' into feat/node-execution-retry 2024-12-18 15:33:24 +08:00
Novice Lee
357769c72e feat: handle http node retry 2024-12-18 15:30:14 +08:00
xander-art
56434db4f5 feat:add hunyuan model(hunyuan-role, hunyuan-large, hunyuan-large-rol… (#11766)
Co-authored-by: xanderdong <xanderdong@tencent.com>
2024-12-18 15:25:53 +08:00
-LAN-
688292e6ff chore(opendal_storage): remove unused comment (#11783)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-18 15:20:54 +08:00
Shun Miyazawa
f7415e1ca4 feat: Disable the "Forgot your password?" button when the mail server setup is incomplete (#11653) 2024-12-18 15:20:41 +08:00
-LAN-
2961fa0e08 chore(.env.example): add comments for opendal (#11778)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-18 15:20:03 +08:00
Jiang
ad17ff9a92 Lindorm vdb bug-fix (#11790)
Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
2024-12-18 15:19:20 +08:00
Benjamin
558ab25f51 fix: imperfect service-api introduction text (#11782) 2024-12-18 13:43:34 +08:00
-LAN-
a5db7c9acb feat: add openai o1 & update pricing and max_token of other models (#11780)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-18 12:15:11 +08:00
Joe
580297e290 fix: file upload auth (#11774) 2024-12-18 11:02:40 +08:00
Novice Lee
853b9af09c Merge branch 'main' into feat/node-execution-retry 2024-12-18 09:38:18 +08:00
DDDDD12138
79d11ea709 feat: add parameters for JinaReaderTool (#11613) 2024-12-18 09:08:06 +08:00
-LAN-
99f40a9682 feat: full support for opendal and sync configurations between .env and docker-compose (#11754)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-18 09:05:54 +08:00
-LAN-
e86756cb39 feat(app_factory): speed up api startup (#11762)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-18 09:05:31 +08:00
barabicu
1325246da8 fix: Prevent redirection to /overview when accessing /workflow. (#11733) 2024-12-18 08:37:22 +08:00
Hiroshi Fujita
dfa9a91906 (doc) fix: update cURL examples to include Authorization header (#11750) 2024-12-17 17:44:40 +08:00
Charlie.Wei
5e2926a207 Fix explore app icon (#11742)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-12-17 17:42:44 +08:00
非法操作
9048832a9a chore: improve gemini models (#11745) 2024-12-17 17:42:21 +08:00
Shota Totsuka
7d5a385811 feat: use Gemini response metadata for token counting (#11743) 2024-12-17 17:42:05 +08:00
Novice Lee
b99f1a09f4 feat: workflow node support retry 2024-12-17 16:50:07 +08:00
-LAN-
900e93f758 chore: update comments in docker env file (#11705)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-17 15:45:00 +08:00
sino
99430a5931 feat(ark): support doubao vision series models (#11740) 2024-12-17 15:43:11 +08:00
非法操作
c9b4029ce7 chore: the consistency of MultiModalPromptMessageContent (#11721) 2024-12-17 15:01:38 +08:00
Bowen Liang
78c3051585 fix: make tidb service optional with proper profile in docker compose yaml (#11729) 2024-12-17 14:25:15 +08:00
呆萌闷油瓶
cd4310df25 chore:update azure api version (#11711) 2024-12-17 13:39:56 +08:00
-LAN-
259cff9f22 fix(api/ops_trace): avoid raise exception directly (#11732)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-17 13:38:57 +08:00
Hanqing Zhao
7b7eb00385 Modify translation for error branch (#11731) 2024-12-17 13:28:13 +08:00
-LAN-
62b9e5a6f9 feat(knowledge_retrieval_node): Suppress exceptions thrown by DatasetRetrieval (#11728)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-17 13:12:29 +08:00
NFish
a399502ecd Dark Mode: Workflow darkmode style (#11695) 2024-12-17 12:20:49 +08:00
-LAN-
92a840f1b2 feat(tool_node): Suppress exceptions thrown by the Tool (#11724)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-17 12:11:50 +08:00
非法操作
74fdc16bd1 feat: enhance gemini models (#11497) 2024-12-17 12:05:13 +08:00
yihong
56cfdce453 chore: update docker env close #11703 (#11706)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-17 09:01:23 +08:00
yihong
efa8eb379f fix: memory leak by pypdfium2 close(maybe) #11510 (#11700)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-17 00:42:01 +08:00
crazywoola
7f095bdc42 fix: image icon can not display (#11701) 2024-12-16 19:15:23 +08:00
Kazuhisa Wada
e20161b3de make login lockout duration configurable (#11699) 2024-12-16 19:05:27 +08:00
方程
fc8fdbacb4 feat: add gitee ai vl models (#11697)
Co-authored-by: 方程 <fangcheng@oschina.cn>
2024-12-16 18:45:26 +08:00
longfengpili
7fde638556 fix: fix proxy for docker (#11681) 2024-12-16 18:43:59 +08:00
非法操作
be93c19b7e chore: remove duplicate folder with case sensitivity issue (#11687) 2024-12-16 17:59:00 +08:00
200 changed files with 3484 additions and 1163 deletions

View File

@@ -50,6 +50,9 @@ jobs:
- name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -C api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh

View File

@@ -9,5 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"

View File

@@ -60,17 +60,8 @@ DB_DATABASE=dify
STORAGE_TYPE=opendal
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
STORAGE_OPENDAL_SCHEME=fs
# OpenDAL FS
OPENDAL_SCHEME=fs
OPENDAL_FS_ROOT=storage
# OpenDAL S3
OPENDAL_S3_ROOT=/
OPENDAL_S3_BUCKET=your-bucket-name
OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
OPENDAL_S3_ACCESS_KEY_ID=your-access-key
OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
OPENDAL_S3_REGION=your-region
OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
# S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false
@@ -313,8 +304,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
@@ -435,3 +425,5 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400

View File

@@ -70,7 +70,6 @@ ignore = [
"SIM113", # eumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"SIM300", # yoda-conditions,
]
[lint.per-file-ignores]

View File

@@ -1,13 +1,30 @@
from app_factory import create_app
from libs import threadings_utils, version_utils
from libs import version_utils
# preparation before creating app
version_utils.check_supported_python_version()
threadings_utils.apply_gevent_threading_patch()
def is_db_command():
import sys
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False
# create app
app = create_app()
celery = app.extensions["celery"]
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
else:
from app_factory import create_app
from libs import threadings_utils
threadings_utils.apply_gevent_threading_patch()
app = create_app()
celery = app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)

View File

@@ -1,5 +1,4 @@
import logging
import os
import time
from configs import dify_config
@@ -17,15 +16,6 @@ def create_flask_app_with_configs() -> DifyApp:
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
# populate configs into system environment variables
for key, value in dify_app.config.items():
if isinstance(value, str):
os.environ[key] = value
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ""
return dify_app
@@ -98,3 +88,14 @@ def initialize_extensions(app: DifyApp):
end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
def create_migrations_app():
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
# Initialize only required extensions
ext_database.init_app(app)
ext_migrate.init_app(app)
return app

View File

@@ -485,6 +485,11 @@ class AuthConfig(BaseSettings):
default=60,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""
@@ -660,14 +665,9 @@ class IndexingConfig(BaseSettings):
)
class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
@@ -773,13 +773,13 @@ class FeatureConfig(
FileAccessConfig,
FileUploadConfig,
HttpConfig,
VisionFormatConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,

View File

@@ -1,51 +1,9 @@
from enum import StrEnum
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
class OpenDALScheme(StrEnum):
FS = "fs"
S3 = "s3"
class OpenDALStorageConfig(BaseSettings):
STORAGE_OPENDAL_SCHEME: str = Field(
default=OpenDALScheme.FS.value,
OPENDAL_SCHEME: str = Field(
default="fs",
description="OpenDAL scheme.",
)
# FS
OPENDAL_FS_ROOT: str = Field(
default="storage",
description="Root path for local storage.",
)
# S3
OPENDAL_S3_ROOT: str = Field(
default="/",
description="Root path for S3 storage.",
)
OPENDAL_S3_BUCKET: str = Field(
default="",
description="S3 bucket name.",
)
OPENDAL_S3_ENDPOINT: str = Field(
default="https://s3.amazonaws.com",
description="S3 endpoint URL.",
)
OPENDAL_S3_ACCESS_KEY_ID: str = Field(
default="",
description="S3 access key ID.",
)
OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
default="",
description="S3 secret access key.",
)
OPENDAL_S3_REGION: str = Field(
default="",
description="S3 region.",
)
OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
default="",
description="S3 server-side encryption.",
)

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.0",
default="0.14.1",
)
COMMIT_SHA: str = Field(

View File

@@ -31,7 +31,7 @@ def admin_required(view):
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if dify_config.ADMIN_API_KEY != auth_token:
if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
return view(*args, **kwargs)

View File

@@ -65,7 +65,7 @@ class ModelConfigResource(Resource):
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}",
)
except Exception as e:
except Exception:
continue
# get decrypted parameters
@@ -97,7 +97,7 @@ class ModelConfigResource(Resource):
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
except Exception as e:
except Exception:
continue
manager = ToolParameterConfigurationManager(

View File

@@ -1,4 +1,5 @@
from flask_restful import Resource, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
@@ -26,7 +27,7 @@ class TraceAppConfigApi(Resource):
return {"has_not_configured": True}
return trace_config
except Exception as e:
raise e
raise BadRequest(str(e))
@setup_required
@login_required
@@ -48,7 +49,7 @@ class TraceAppConfigApi(Resource):
raise TracingConfigCheckError()
return result
except Exception as e:
raise e
raise BadRequest(str(e))
@setup_required
@login_required
@@ -68,7 +69,7 @@ class TraceAppConfigApi(Resource):
raise TracingConfigNotExist()
return {"result": "success"}
except Exception as e:
raise e
raise BadRequest(str(e))
@setup_required
@login_required
@@ -85,7 +86,7 @@ class TraceAppConfigApi(Resource):
raise TracingConfigNotExist()
return {"result": "success"}
except Exception as e:
raise e
raise BadRequest(str(e))
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

View File

@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
if document.indexing_status == "completed":
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception as e:
except Exception:
logging.exception(f"Failed to retry document, document id: {document_id}")
continue
# retry document

View File

@@ -4,6 +4,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import login_required
from services.recommended_app_service import RecommendedAppService
@@ -12,6 +13,8 @@ app_fields = {
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}

View File

@@ -1,6 +1,7 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
from configs import dify_config
@@ -58,6 +59,9 @@ class FileApi(Resource):
if not file.filename:
raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden()
if source not in ("datasets", None):
source = None

View File

@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):

View File

@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

View File

@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
start_index=event.start_index,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

View File

@@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
ERROR = "error"
PING = "ping"
STOP = "stop"
RETRY = "retry"
class AppQueueEvent(BaseModel):
@@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueNodeRetryEvent(AppQueueEvent):
"""QueueNodeRetryEvent entity"""
event: QueueEvent = QueueEvent.RETRY
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
retry_index: int # retry index
start_index: int # start index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity

View File

@@ -52,6 +52,7 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
}
class NodeRetryStreamResponse(StreamResponse):
"""
NodeFinishStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
"workflow_run_id": self.workflow_run_id,
"data": {
"id": self.data.id,
"node_id": self.data.node_id,
"node_type": self.data.node_type,
"title": self.data.title,
"index": self.data.index,
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"process_data": None,
"outputs": None,
"status": self.data.status,
"error": None,
"elapsed_time": self.data.elapsed_time,
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"retry_index": self.data.retry_index,
},
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity

View File

@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
@@ -271,9 +273,9 @@ class WorkflowCycleManage:
db.session.close()
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
# with Session(db.engine, expire_on_commit=False) as session:
# session.add(workflow_run)
# session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
@@ -423,6 +425,52 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.index = event.start_index
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
return workflow_node_execution
#################################################
# to stream responses #
#################################################
@@ -587,6 +635,51 @@ class WorkflowCycleManage:
),
)
def _workflow_node_retry_to_stream_response(
self,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

View File

@@ -42,39 +42,31 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO:
encoded_string = _get_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _:
raise ValueError(f"file type {f.type} is not supported")
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /):
@@ -128,11 +120,6 @@ def _get_encoded_string(f: File, /):
return encoded_string
def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:

View File

@@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:

View File

@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator
class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
"""
type: PromptMessageContentType
data: str
class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str
class VideoPromptMessageContent(PromptMessageContent):
class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""
type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent):
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")
class ImagePromptMessageContent(PromptMessageContent):
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
@@ -103,11 +116,8 @@ class ImagePromptMessageContent(PromptMessageContent):
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(PromptMessageContent):
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
class PromptMessage(ABC, BaseModel):

View File

@@ -1,5 +1,4 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
if not message_content.base64_data:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
image_content = requests.get(message_content.url).content
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
base64_data = message_content.base64_data
mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
@@ -534,7 +529,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
},

View File

@@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
),
),
),
AzureBaseModel(
base_model_name="gpt-4o-2024-11-20",
entity=AIModelEntity(
model="fake-deployment-name",
label=I18nObject(
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=16384),
ParameterRule(
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
),
required=False,
options=["text", "json_object", "json_schema"],
),
ParameterRule(
name="json_schema",
label=I18nObject(en_US="JSON Schema"),
type="text",
help=I18nObject(
zh_Hans="设置返回的json schemallm将按照它返回",
en_US="Set a response json schema will ensure LLM to adhere it.",
),
required=False,
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency="USD",
),
),
),
AzureBaseModel(
base_model_name="gpt-4-turbo",
entity=AIModelEntity(

View File

@@ -86,6 +86,9 @@ model_credential_schema:
- label:
en_US: '2024-06-01'
value: '2024-06-01'
- label:
en_US: '2024-10-21'
value: '2024-10-21'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
@@ -168,6 +171,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-11-20
value: gpt-4o-2024-11-20
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo

View File

@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -1,11 +1,19 @@
from collections.abc import Mapping
import boto3
from botocore.config import Config
from core.model_runtime.errors.invoke import InvokeBadRequestError
def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
region_name = credentials.get("aws_region")
if not region_name:
raise InvokeBadRequestError("aws_region is required")
client_config = Config(region_name=region_name)
aws_access_key_id = credentials.get("aws_access_key_id")
aws_secret_access_key = credentials.get("aws_secret_access_key")
def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(

View File

@@ -62,7 +62,10 @@ class BedrockRerankModel(RerankModel):
}
)
modelId = model
region = credentials["aws_region"]
region = credentials.get("aws_region")
# region is a required field
if not region:
raise InvokeBadRequestError("aws_region is required in credentials")
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",

View File

@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -0,0 +1,93 @@
model: InternVL2-8B
label:
en_US: InternVL2-8B
model_type: llm
features:
- vision
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@@ -0,0 +1,93 @@
model: InternVL2.5-26B
label:
en_US: InternVL2.5-26B
model_type: llm
features:
- vision
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@@ -6,3 +6,5 @@
- deepseek-coder-33B-instruct-chat
- deepseek-coder-33B-instruct-completions
- codegeex4-all-9b
- InternVL2.5-26B
- InternVL2-8B

View File

@@ -29,18 +29,26 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model, model_parameters)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return super()._invoke(
GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model),
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials, None)
super().validate_credentials(model, credentials)
self._add_custom_parameters(credentials, model, None)
super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials)
def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None:
if model is None:
model = "Qwen2-72B-Instruct"
model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
credentials["endpoint_url"] = "https://ai.gitee.com/v1"
if model.endswith("completions"):
credentials["mode"] = LLMMode.COMPLETION.value
else:

View File

@@ -1,3 +1,5 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001
@@ -11,6 +13,8 @@
- gemini-1.5-flash-exp-0827
- gemini-1.5-flash-8b-exp-0827
- gemini-1.5-flash-8b-exp-0924
- gemini-exp-1206
- gemini-exp-1121
- gemini-exp-1114
- gemini-pro
- gemini-pro-vision

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -7,6 +7,9 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -7,6 +7,9 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -7,6 +7,9 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -1,24 +1,23 @@
import base64
import io
import json
import os
import tempfile
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Optional, Union
import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
@@ -35,21 +34,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
from extensions.ext_redis import redis_client
class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -201,29 +186,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
if stop:
config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)
history = []
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
response = google_model.generate_content(
contents=history,
@@ -317,8 +290,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
)
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if hasattr(response, "usage_metadata") and response.usage_metadata:
prompt_tokens = response.usage_metadata.prompt_token_count
completion_tokens = response.usage_metadata.candidates_token_count
else:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -346,7 +323,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@@ -359,6 +336,40 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return message_text
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
key = f"{message_content.type.value}:{hash(message_content.data)}"
if redis_client.exists(key):
try:
return genai.get_file(redis_client.get(key).decode())
except:
pass
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
if message_content.base64_data:
file_content = base64.b64decode(message_content.base64_data)
temp_file.write(file_content)
else:
try:
response = requests.get(message_content.url)
response.raise_for_status()
temp_file.write(response.content)
except Exception as ex:
raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}")
temp_file.flush()
file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type)
while file.state.name == "PROCESSING":
time.sleep(5)
file = genai.get_file(file.name)
# google will delete your upload files in 2 days.
redis_client.setex(key, 47 * 60 * 60, file.name)
try:
os.unlink(temp_file.name)
except PermissionError:
# windows may raise permission error
pass
return file
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
@@ -374,28 +385,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
else:
glm_content["parts"].append(self._upload_file_content_to_google(c))
return glm_content
elif isinstance(message, AssistantPromptMessage):

View File

@@ -3,8 +3,8 @@ label:
zh_Hans: 腾讯混元
en_US: Hunyuan
description:
en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro and hunyuan-lite.
zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro 和 hunyuan-lite。
en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall and hunyuan-lite.
zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall 和 hunyuan-lite。
icon_small:
en_US: icon_s_en.png
icon_large:

View File

@@ -4,3 +4,10 @@
- hunyuan-pro
- hunyuan-turbo
- hunyuan-vision
- hunyuan-role
- hunyuan-large
- hunyuan-large-role
- hunyuan-large-longcontext
- hunyuan-turbo-latest
- hunyuan-turbo-vision
- hunyuan-functioncall

View File

@@ -0,0 +1,38 @@
model: hunyuan-functioncall
label:
zh_Hans: hunyuan-functioncall
en_US: hunyuan-functioncall
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.004'
output: '0.008'
unit: '0.001'
currency: RMB

View File

@@ -0,0 +1,38 @@
model: hunyuan-large-longcontext
label:
zh_Hans: hunyuan-large-longcontext
en_US: hunyuan-large-longcontext
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 134000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 134000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.006'
output: '0.018'
unit: '0.001'
currency: RMB

View File

@@ -0,0 +1,38 @@
model: hunyuan-large-role
label:
zh_Hans: hunyuan-large-role
en_US: hunyuan-large-role
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.004'
output: '0.008'
unit: '0.001'
currency: RMB

View File

@@ -0,0 +1,38 @@
model: hunyuan-large
label:
zh_Hans: hunyuan-large
en_US: hunyuan-large
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.004'
output: '0.012'
unit: '0.001'
currency: RMB

View File

@@ -0,0 +1,38 @@
model: hunyuan-role
label:
zh_Hans: hunyuan-role
en_US: hunyuan-role
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.004'
output: '0.008'
unit: '0.001'
currency: RMB

View File

@@ -0,0 +1,38 @@
model: hunyuan-turbo-latest
label:
zh_Hans: hunyuan-turbo-latest
en_US: hunyuan-turbo-latest
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.015'
output: '0.05'
unit: '0.001'
currency: RMB

View File

@@ -0,0 +1,39 @@
model: hunyuan-turbo-vision
label:
zh_Hans: hunyuan-turbo-vision
en_US: hunyuan-turbo-vision
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: enable_enhance
label:
zh_Hans: 功能增强
en_US: Enable Enhancement
type: boolean
help:
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
default: true
pricing:
input: '0.08'
output: '0.08'
unit: '0.001'
currency: RMB

View File

@@ -1,4 +1,7 @@
- gpt-4o-audio-preview
- o1
- o1-2024-12-17
- o1-mini
- o1-mini-2024-09-12
- gpt-4
- gpt-4o
- gpt-4o-2024-05-13
@@ -7,10 +10,6 @@
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- o1-preview
- o1-preview-2024-09-12
- o1-mini
- o1-mini-2024-09-12
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview
@@ -25,4 +24,7 @@
- gpt-3.5-turbo-1106
- gpt-3.5-turbo-0613
- gpt-3.5-turbo-instruct
- gpt-4o-audio-preview
- o1-preview
- o1-preview-2024-09-12
- text-davinci-003

View File

@@ -22,7 +22,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 16384
- name: response_format

View File

@@ -22,9 +22,9 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 4096
max: 16384
- name: response_format
label:
zh_Hans: 回复格式

View File

@@ -22,7 +22,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 16384
- name: response_format

View File

@@ -22,7 +22,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 16384
- name: response_format

View File

@@ -22,9 +22,9 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 4096
max: 16384
- name: response_format
label:
zh_Hans: 回复格式

View File

@@ -22,7 +22,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 16384
- name: response_format

View File

@@ -22,7 +22,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 16384
- name: response_format

View File

@@ -22,9 +22,9 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
default: 16384
min: 1
max: 4096
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
@@ -38,7 +38,7 @@ parameter_rules:
- text
- json_object
pricing:
input: '5.00'
output: '15.00'
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, AudioPromptMessageContent):
data_split = message_content.data.split(";base64,")
base64_data = data_split[1]
sub_message_dict = {
"type": "input_audio",
"input_audio": {
"data": message_content.data,
"data": base64_data,
"format": message_content.format,
},
}

View File

@@ -0,0 +1,35 @@
model: o1-2024-12-17
label:
en_US: o1-2024-12-17
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 50000
min: 1
max: 50000
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '15.00'
output: '60.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,36 @@
model: o1
label:
zh_Hans: o1
en_US: o1
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: max_tokens
use_template: max_tokens
default: 50000
min: 1
max: 50000
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '15.00'
output: '60.00'
unit: '0.000001'
currency: USD

View File

@@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -119,7 +119,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
embeddings.append(result[0].get("embedding"))
return [list(map(float, e)) for e in embeddings]
elif "texts" == text_input_key:
elif text_input_key == "texts":
result = client.run(
replicate_model_version,
input={

View File

@@ -18,7 +18,7 @@ class SiliconflowProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials)
model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

View File

@@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
video_url = message_content.url
if not video_url:
raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)

View File

@@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -1,4 +1,3 @@
import re
from collections.abc import Generator
from typing import Optional, cast
@@ -104,17 +103,16 @@ class ArkClientV3:
if message_content.type == PromptMessageContentType.TEXT:
content.append(
ChatCompletionContentPartTextParam(
text=message_content.text,
text=message_content.data,
type="text",
)
)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data)
content.append(
ChatCompletionContentPartImageParam(
image_url=ImageURL(
url=image_data,
url=message_content.data,
detail=message_content.detail.value,
),
type="image_url",

View File

@@ -132,6 +132,14 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages]
for message in messages_dict:
for key, value in message.items():
# Ignore tokens for image type
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
num_tokens += self._get_num_tokens_by_gpt2(str(key))
num_tokens += self._get_num_tokens_by_gpt2(str(value))

View File

@@ -16,6 +16,14 @@ class ModelConfig(BaseModel):
configs: dict[str, ModelConfig] = {
"Doubao-vision-pro-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.VISION],
),
"Doubao-vision-lite-32k": ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.VISION],
),
"Doubao-pro-4k": ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
@@ -32,6 +40,10 @@ configs: dict[str, ModelConfig] = {
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
),
"Doubao-pro-256k": ModelConfig(
properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
features=[],
),
"Doubao-pro-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],

View File

@@ -12,6 +12,7 @@ class ModelConfig(BaseModel):
ModelConfigs = {
"Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
"Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
}
@@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get("context_size", 0)),
context_size=int(credentials.get("context_size", 4096)),
max_chunks=int(credentials.get("max_chunks", 1)),
)
)

View File

@@ -118,6 +118,18 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: Doubao-vision-pro-32k
value: Doubao-vision-pro-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-vision-lite-32k
value: Doubao-vision-lite-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-4k
value: Doubao-pro-4k
@@ -154,6 +166,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-256k
value: Doubao-pro-256k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Llama3-8B
value: Llama3-8B
@@ -208,6 +226,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: Doubao-embedding-large
value: Doubao-embedding-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: Custom
zh_Hans: 自定义

View File

@@ -49,10 +49,10 @@ class LindormVectorStoreConfig(BaseModel):
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs):
self._routing = None
self._routing_field = None
if config.using_ugc:
if using_ugc:
routing_value: str = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
@@ -64,7 +64,7 @@ class LindormVectorStore(BaseVector):
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self._using_ugc = config.using_ugc
self._using_ugc = using_ugc
self.kwargs = kwargs
def get_type(self) -> str:
@@ -467,12 +467,16 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
using_ugc = dify_config.USING_UGC_INDEX
routing_value = None
if dataset.index_struct:
if using_ugc:
# if an existed record's index_struct_dict doesn't contain using_ugc field,
# it actually stores in the normal index format
stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
@@ -487,6 +491,7 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
"index_type": index_type,
"dimension": dimension,
"distance_type": distance_type,
"using_ugc": using_ugc,
}
dataset.index_struct = json.dumps(index_struct_dict)
if using_ugc:
@@ -494,4 +499,4 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
routing_value = class_prefix
else:
index_name = class_prefix
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc)

View File

@@ -65,6 +65,11 @@ class CacheEmbedding(Embeddings):
for vector in embedding_result.embeddings:
try:
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
@@ -111,6 +116,8 @@ class CacheEmbedding(Embeddings):
embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")

View File

@@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
ws = websocket.WebSocket()
base_url = URL(credentials.get("base_url"))
ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
ws_protocol = "ws"
if base_url.scheme == "https":
ws_protocol = "wss"
ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123"
try:
ws.connect(ws_address)

View File

@@ -40,7 +40,10 @@ class ComfyUiClient:
def open_websocket_connection(self) -> tuple[WebSocket, str]:
client_id = str(uuid.uuid4())
ws = WebSocket()
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
ws_protocol = "ws"
if self.base_url.scheme == "https":
ws_protocol = "wss"
ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}"
ws.connect(ws_address)
return ws, client_id

View File

@@ -43,6 +43,13 @@ class JinaReaderTool(BuiltinTool):
if wait_for_selector is not None and wait_for_selector != "":
headers["X-Wait-For-Selector"] = wait_for_selector
remove_selector = tool_parameters.get("remove_selector")
if remove_selector is not None and remove_selector != "":
headers["X-Remove-Selector"] = remove_selector
if tool_parameters.get("retain_images", False):
headers["X-Retain-Images"] = "true"
if tool_parameters.get("image_caption", False):
headers["X-With-Generated-Alt"] = "true"
@@ -59,6 +66,12 @@ class JinaReaderTool(BuiltinTool):
if tool_parameters.get("no_cache", False):
headers["X-No-Cache"] = "true"
if tool_parameters.get("with_iframe", False):
headers["X-With-Iframe"] = "true"
if tool_parameters.get("with_shadow_dom", False):
headers["X-With-Shadow-Dom"] = "true"
max_retries = tool_parameters.get("max_retries", 3)
response = ssrf_proxy.get(
str(URL(self._jina_reader_endpoint + url)),

View File

@@ -67,6 +67,33 @@ parameters:
pt_BR: css selector para aguardar elementos específicos
llm_description: css selector of the target element to wait for
form: form
- name: remove_selector
type: string
required: false
label:
en_US: Excluded Selector
zh_Hans: 排除选择器
pt_BR: Seletor Excluído
human_description:
en_US: css selector for remove for specific elements
zh_Hans: css 选择器用于排除特定元素
pt_BR: seletor CSS para remover elementos específicos
llm_description: css selector of the target element to remove for
form: form
- name: retain_images
type: boolean
required: false
default: false
label:
en_US: Remove All Images
zh_Hans: 删除所有图片
pt_BR: Remover todas as imagens
human_description:
en_US: Removes all images from the response.
zh_Hans: 从响应中删除所有图片。
pt_BR: Remove todas as imagens da resposta.
llm_description: Remove all images
form: form
- name: image_caption
type: boolean
required: false
@@ -136,6 +163,34 @@ parameters:
pt_BR: Ignorar o cache
llm_description: bypass the cache
form: form
- name: with_iframe
type: boolean
required: false
default: false
label:
en_US: Enable iframe extraction
zh_Hans: 启用 iframe 提取
pt_BR: Habilitar extração de iframe
human_description:
en_US: Extract and process content of all embedded iframes in the DOM tree.
zh_Hans: 提取并处理 DOM 树中所有嵌入 iframe 的内容。
pt_BR: Extrair e processar o conteúdo de todos os iframes incorporados na árvore DOM.
llm_description: Extract content from embedded iframes
form: form
- name: with_shadow_dom
type: boolean
required: false
default: false
label:
en_US: Enable Shadow DOM extraction
zh_Hans: 启用 Shadow DOM 提取
pt_BR: Habilitar extração de Shadow DOM
human_description:
en_US: Traverse all Shadow DOM roots in the document and extract content.
zh_Hans: 遍历文档中所有 Shadow DOM 根并提取内容。
pt_BR: Percorra todas as raízes do Shadow DOM no documento e extraia o conteúdo.
llm_description: Extract content from Shadow DOM roots
form: form
- name: summary
type: boolean
required: false

View File

@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed
# single step node run retry
retry_index: int = 0

View File

@@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(BaseNodeEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
start_index: int = Field(..., description="retry start index")
###########################################
# Parallel Branch Events
###########################################

View File

@@ -5,6 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode,
node_instance: BaseNode[BaseNodeData],
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,121 @@ class GraphEngine:
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
shoudl_continue_retry = True
while shoudl_continue_retry and retries <= max_retries:
try:
# run node
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
retries += 1
self.graph_runtime_state.node_run_steps += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
error=run_result.error,
retry_index=retries,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
start_at=retry_start_at,
start_index=self.graph_runtime_state.node_run_steps,
)
time.sleep(retry_interval)
continue
route_node_state.set_finished(run_result=run_result)
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
shoudl_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
shoudl_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
@@ -645,21 +732,23 @@ class GraphEngine:
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
@@ -670,108 +759,59 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
shoudl_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

View File

@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
@property
def should_retry(self) -> bool:
"""judge if should retry
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

View File

@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]

View File

@@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
@@ -6,5 +12,6 @@ __all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

Some files were not shown because too many files have changed in this diff Show More