mirror of
https://github.com/langgenius/dify.git
synced 2026-04-12 23:49:27 +08:00
Compare commits
58 Commits
0.14.0
...
feat/node-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fde3fe0ab6 | ||
|
|
07528f82b9 | ||
|
|
bb2f46d7cc | ||
|
|
463fbe2680 | ||
|
|
95a7e50137 | ||
|
|
9d93ad1f16 | ||
|
|
44104797d6 | ||
|
|
1548501050 | ||
|
|
de3911e930 | ||
|
|
5a8a901560 | ||
|
|
12d45e9114 | ||
|
|
d057067543 | ||
|
|
560d375e0f | ||
|
|
127291a90f | ||
|
|
9e0c28791d | ||
|
|
3388d6636c | ||
|
|
2624a6dcd0 | ||
|
|
b5c2785e10 | ||
|
|
493834d45d | ||
|
|
926546b153 | ||
|
|
b411087bb7 | ||
|
|
357769c72e | ||
|
|
56434db4f5 | ||
|
|
688292e6ff | ||
|
|
f7415e1ca4 | ||
|
|
2961fa0e08 | ||
|
|
ad17ff9a92 | ||
|
|
558ab25f51 | ||
|
|
a5db7c9acb | ||
|
|
580297e290 | ||
|
|
853b9af09c | ||
|
|
79d11ea709 | ||
|
|
99f40a9682 | ||
|
|
e86756cb39 | ||
|
|
1325246da8 | ||
|
|
dfa9a91906 | ||
|
|
5e2926a207 | ||
|
|
9048832a9a | ||
|
|
7d5a385811 | ||
|
|
b99f1a09f4 | ||
|
|
900e93f758 | ||
|
|
99430a5931 | ||
|
|
c9b4029ce7 | ||
|
|
78c3051585 | ||
|
|
cd4310df25 | ||
|
|
259cff9f22 | ||
|
|
7b7eb00385 | ||
|
|
62b9e5a6f9 | ||
|
|
a399502ecd | ||
|
|
92a840f1b2 | ||
|
|
74fdc16bd1 | ||
|
|
56cfdce453 | ||
|
|
efa8eb379f | ||
|
|
7f095bdc42 | ||
|
|
e20161b3de | ||
|
|
fc8fdbacb4 | ||
|
|
7fde638556 | ||
|
|
be93c19b7e |
3
.github/workflows/api-tests.yml
vendored
3
.github/workflows/api-tests.yml
vendored
@@ -50,6 +50,9 @@ jobs:
|
||||
- name: Run ModelRuntime
|
||||
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
|
||||
|
||||
- name: Run dify config tests
|
||||
run: poetry run -C api python dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: Run Tool
|
||||
run: poetry run -C api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
|
||||
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@@ -9,5 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
|
||||
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
|
||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
|
||||
|
||||
@@ -60,17 +60,8 @@ DB_DATABASE=dify
|
||||
STORAGE_TYPE=opendal
|
||||
|
||||
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
|
||||
STORAGE_OPENDAL_SCHEME=fs
|
||||
# OpenDAL FS
|
||||
OPENDAL_SCHEME=fs
|
||||
OPENDAL_FS_ROOT=storage
|
||||
# OpenDAL S3
|
||||
OPENDAL_S3_ROOT=/
|
||||
OPENDAL_S3_BUCKET=your-bucket-name
|
||||
OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
|
||||
OPENDAL_S3_ACCESS_KEY_ID=your-access-key
|
||||
OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
|
||||
OPENDAL_S3_REGION=your-region
|
||||
OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
|
||||
|
||||
# S3 Storage configuration
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
@@ -313,8 +304,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||
MULTIMODAL_SEND_VIDEO_FORMAT=base64
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
|
||||
@@ -435,3 +425,5 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
|
||||
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
|
||||
MAX_SUBMIT_COUNT=100
|
||||
# Lockout duration in seconds
|
||||
LOGIN_LOCKOUT_DURATION=86400
|
||||
@@ -70,7 +70,6 @@ ignore = [
|
||||
"SIM113", # eumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"SIM300", # yoda-conditions,
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
||||
27
api/app.py
27
api/app.py
@@ -1,13 +1,30 @@
|
||||
from app_factory import create_app
|
||||
from libs import threadings_utils, version_utils
|
||||
from libs import version_utils
|
||||
|
||||
# preparation before creating app
|
||||
version_utils.check_supported_python_version()
|
||||
threadings_utils.apply_gevent_threading_patch()
|
||||
|
||||
|
||||
def is_db_command():
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# create app
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
else:
|
||||
from app_factory import create_app
|
||||
from libs import threadings_utils
|
||||
|
||||
threadings_utils.apply_gevent_threading_patch()
|
||||
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
@@ -17,15 +16,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# populate configs into system environment variables
|
||||
for key, value in dify_app.config.items():
|
||||
if isinstance(value, str):
|
||||
os.environ[key] = value
|
||||
elif isinstance(value, int | float | bool):
|
||||
os.environ[key] = str(value)
|
||||
elif value is None:
|
||||
os.environ[key] = ""
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
@@ -98,3 +88,14 @@ def initialize_extensions(app: DifyApp):
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
|
||||
|
||||
|
||||
def create_migrations_app():
|
||||
app = create_flask_app_with_configs()
|
||||
from extensions import ext_database, ext_migrate
|
||||
|
||||
# Initialize only required extensions
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
@@ -485,6 +485,11 @@ class AuthConfig(BaseSettings):
|
||||
default=60,
|
||||
)
|
||||
|
||||
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
|
||||
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
|
||||
default=86400,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@@ -660,14 +665,9 @@ class IndexingConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class VisionFormatConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
|
||||
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
||||
default="base64",
|
||||
)
|
||||
|
||||
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
|
||||
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
|
||||
class MultiModalTransferConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
|
||||
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
|
||||
default="base64",
|
||||
)
|
||||
|
||||
@@ -773,13 +773,13 @@ class FeatureConfig(
|
||||
FileAccessConfig,
|
||||
FileUploadConfig,
|
||||
HttpConfig,
|
||||
VisionFormatConfig,
|
||||
InnerAPIConfig,
|
||||
IndexingConfig,
|
||||
LoggingConfig,
|
||||
MailConfig,
|
||||
ModelLoadBalanceConfig,
|
||||
ModerationConfig,
|
||||
MultiModalTransferConfig,
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
SecurityConfig,
|
||||
|
||||
@@ -1,51 +1,9 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OpenDALScheme(StrEnum):
|
||||
FS = "fs"
|
||||
S3 = "s3"
|
||||
|
||||
|
||||
class OpenDALStorageConfig(BaseSettings):
|
||||
STORAGE_OPENDAL_SCHEME: str = Field(
|
||||
default=OpenDALScheme.FS.value,
|
||||
OPENDAL_SCHEME: str = Field(
|
||||
default="fs",
|
||||
description="OpenDAL scheme.",
|
||||
)
|
||||
# FS
|
||||
OPENDAL_FS_ROOT: str = Field(
|
||||
default="storage",
|
||||
description="Root path for local storage.",
|
||||
)
|
||||
# S3
|
||||
OPENDAL_S3_ROOT: str = Field(
|
||||
default="/",
|
||||
description="Root path for S3 storage.",
|
||||
)
|
||||
OPENDAL_S3_BUCKET: str = Field(
|
||||
default="",
|
||||
description="S3 bucket name.",
|
||||
)
|
||||
OPENDAL_S3_ENDPOINT: str = Field(
|
||||
default="https://s3.amazonaws.com",
|
||||
description="S3 endpoint URL.",
|
||||
)
|
||||
OPENDAL_S3_ACCESS_KEY_ID: str = Field(
|
||||
default="",
|
||||
description="S3 access key ID.",
|
||||
)
|
||||
OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
|
||||
default="",
|
||||
description="S3 secret access key.",
|
||||
)
|
||||
OPENDAL_S3_REGION: str = Field(
|
||||
default="",
|
||||
description="S3 region.",
|
||||
)
|
||||
OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
|
||||
default="",
|
||||
description="S3 server-side encryption.",
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.14.0",
|
||||
default="0.14.1",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@@ -31,7 +31,7 @@ def admin_required(view):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if dify_config.ADMIN_API_KEY != auth_token:
|
||||
if auth_token != dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -65,7 +65,7 @@ class ModelConfigResource(Resource):
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f"AGENT.{app_model.id}",
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# get decrypted parameters
|
||||
@@ -97,7 +97,7 @@ class ModelConfigResource(Resource):
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
@@ -26,7 +27,7 @@ class TraceAppConfigApi(Resource):
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -48,7 +49,7 @@ class TraceAppConfigApi(Resource):
|
||||
raise TracingConfigCheckError()
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -68,7 +69,7 @@ class TraceAppConfigApi(Resource):
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -85,7 +86,7 @@ class TraceAppConfigApi(Resource):
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise BadRequest(str(e))
|
||||
|
||||
|
||||
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
|
||||
|
||||
@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
if document.indexing_status == "completed":
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception(f"Failed to retry document, document id: {document_id}")
|
||||
continue
|
||||
# retry document
|
||||
|
||||
@@ -4,6 +4,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from libs.login import login_required
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
@@ -12,6 +13,8 @@ app_fields = {
|
||||
"name": fields.String,
|
||||
"mode": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"icon_background": fields.String,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@@ -58,6 +59,9 @@ class FileApi(Resource):
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if source == "datasets" and not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
|
||||
@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
@@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.error,
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
retry_index=event.retry_index,
|
||||
start_index=event.start_index,
|
||||
)
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
|
||||
@@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
@@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(AppQueueEvent):
|
||||
"""QueueNodeRetryEvent entity"""
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRY
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
start_index: int # start index
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
|
||||
@@ -52,6 +52,7 @@ class StreamEvent(Enum):
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
}
|
||||
|
||||
|
||||
class NodeRetryStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeFinishStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
retry_index: int = 0
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_RETRY
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"data": {
|
||||
"id": self.data.id,
|
||||
"node_id": self.data.node_id,
|
||||
"node_type": self.data.node_type,
|
||||
"title": self.data.title,
|
||||
"index": self.data.index,
|
||||
"predecessor_node_id": self.data.predecessor_node_id,
|
||||
"inputs": None,
|
||||
"process_data": None,
|
||||
"outputs": None,
|
||||
"status": self.data.status,
|
||||
"error": None,
|
||||
"elapsed_time": self.data.elapsed_time,
|
||||
"execution_metadata": None,
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": [],
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"retry_index": self.data.retry_index,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
|
||||
@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
@@ -271,9 +273,9 @@ class WorkflowCycleManage:
|
||||
|
||||
db.session.close()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
session.add(workflow_run)
|
||||
session.refresh(workflow_run)
|
||||
# with Session(db.engine, expire_on_commit=False) as session:
|
||||
# session.add(workflow_run)
|
||||
# session.refresh(workflow_run)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@@ -423,6 +425,52 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
created_at = event.start_at
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = created_at
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.index = event.start_index
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
# to stream responses #
|
||||
#################################################
|
||||
@@ -587,6 +635,51 @@ class WorkflowCycleManage:
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_retry_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param event: queue node succeeded or failed event
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
retry_index=event.retry_index,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
|
||||
@@ -42,39 +42,31 @@ def to_prompt_message_content(
|
||||
*,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
):
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||
data = _to_url(f)
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
if f.mime_type is None:
|
||||
raise ValueError("Missing file mime_type")
|
||||
|
||||
return ImagePromptMessageContent(data=data, detail=image_detail_config)
|
||||
case FileType.AUDIO:
|
||||
encoded_string = _get_encoded_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
|
||||
case FileType.VIDEO:
|
||||
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
|
||||
data = _to_url(f)
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||
case FileType.DOCUMENT:
|
||||
data = _get_encoded_string(f)
|
||||
if f.mime_type is None:
|
||||
raise ValueError("Missing file mime_type")
|
||||
return DocumentPromptMessageContent(
|
||||
encode_format="base64",
|
||||
mime_type=f.mime_type,
|
||||
data=data,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
params = {
|
||||
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
|
||||
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
|
||||
"format": f.extension.removeprefix("."),
|
||||
"mime_type": f.mime_type,
|
||||
}
|
||||
if f.type == FileType.IMAGE:
|
||||
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_class_map = {
|
||||
FileType.IMAGE: ImagePromptMessageContent,
|
||||
FileType.AUDIO: AudioPromptMessageContent,
|
||||
FileType.VIDEO: VideoPromptMessageContent,
|
||||
FileType.DOCUMENT: DocumentPromptMessageContent,
|
||||
}
|
||||
|
||||
try:
|
||||
return prompt_class_map[f.type](**params)
|
||||
except KeyError:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
@@ -128,11 +120,6 @@ def _get_encoded_string(f: File, /):
|
||||
return encoded_string
|
||||
|
||||
|
||||
def _to_base64_data_string(f: File, /):
|
||||
encoded_string = _get_encoded_string(f)
|
||||
return f"data:{f.mime_type};base64,{encoded_string}"
|
||||
|
||||
|
||||
def _to_url(f: File, /):
|
||||
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
|
||||
@@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
)
|
||||
|
||||
retries = 0
|
||||
stream = kwargs.pop("stream", False)
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from abc import ABC
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(Enum):
|
||||
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
data: str
|
||||
|
||||
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
data: str
|
||||
|
||||
|
||||
class VideoPromptMessageContent(PromptMessageContent):
|
||||
class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for multi-modal prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
format: str = Field(..., description="the format of multi-modal file")
|
||||
base64_data: str = Field("", description="the base64 data of multi-modal file")
|
||||
url: str = Field("", description="the url of multi-modal file")
|
||||
mime_type: str = Field(..., description="the mime type of multi-modal file")
|
||||
|
||||
@computed_field(return_type=str)
|
||||
@property
|
||||
def data(self):
|
||||
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.VIDEO
|
||||
data: str = Field(..., description="Base64 encoded video data")
|
||||
format: str = Field(..., description="Video format")
|
||||
|
||||
|
||||
class AudioPromptMessageContent(PromptMessageContent):
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.AUDIO
|
||||
data: str = Field(..., description="Base64 encoded audio data")
|
||||
format: str = Field(..., description="Audio format")
|
||||
|
||||
|
||||
class ImagePromptMessageContent(PromptMessageContent):
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
@@ -103,11 +116,8 @@ class ImagePromptMessageContent(PromptMessageContent):
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(PromptMessageContent):
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
||||
encode_format: Literal["base64"]
|
||||
mime_type: str
|
||||
data: str
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union, cast
|
||||
@@ -18,7 +17,6 @@ from anthropic.types import (
|
||||
)
|
||||
from anthropic.types.beta.tools import ToolsBetaMessage
|
||||
from httpx import Timeout
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities import (
|
||||
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
if not message_content.data.startswith("data:"):
|
||||
if not message_content.base64_data:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
image_content = requests.get(message_content.url).content
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(
|
||||
f"Failed to fetch image data from url {message_content.data}, {ex}"
|
||||
)
|
||||
else:
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
base64_data = message_content.base64_data
|
||||
|
||||
mime_type = message_content.mime_type
|
||||
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
|
||||
raise ValueError(
|
||||
f"Unsupported image type {mime_type}, "
|
||||
@@ -534,7 +529,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
sub_message_dict = {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": message_content.encode_format,
|
||||
"type": "base64",
|
||||
"media_type": message_content.mime_type,
|
||||
"data": message_content.data,
|
||||
},
|
||||
|
||||
@@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
|
||||
),
|
||||
),
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name="gpt-4o-2024-11-20",
|
||||
entity=AIModelEntity(
|
||||
model="fake-deployment-name",
|
||||
label=I18nObject(
|
||||
en_US="fake-deployment-name-label",
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.VISION,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
ModelFeature.STREAM_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="presence_penalty",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name="frequency_penalty",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=16384),
|
||||
ParameterRule(
|
||||
name="seed",
|
||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||
type="int",
|
||||
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
type="string",
|
||||
help=I18nObject(
|
||||
zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
|
||||
),
|
||||
required=False,
|
||||
options=["text", "json_object", "json_schema"],
|
||||
),
|
||||
ParameterRule(
|
||||
name="json_schema",
|
||||
label=I18nObject(en_US="JSON Schema"),
|
||||
type="text",
|
||||
help=I18nObject(
|
||||
zh_Hans="设置返回的json schema,llm将按照它返回",
|
||||
en_US="Set a response json schema will ensure LLM to adhere it.",
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=5.00,
|
||||
output=15.00,
|
||||
unit=0.000001,
|
||||
currency="USD",
|
||||
),
|
||||
),
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name="gpt-4-turbo",
|
||||
entity=AIModelEntity(
|
||||
|
||||
@@ -86,6 +86,9 @@ model_credential_schema:
|
||||
- label:
|
||||
en_US: '2024-06-01'
|
||||
value: '2024-06-01'
|
||||
- label:
|
||||
en_US: '2024-10-21'
|
||||
value: '2024-10-21'
|
||||
placeholder:
|
||||
zh_Hans: 在此选择您的 API 版本
|
||||
en_US: Select your API Version here
|
||||
@@ -168,6 +171,12 @@ model_credential_schema:
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-2024-11-20
|
||||
value: gpt-4o-2024-11-20
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo
|
||||
value: gpt-4-turbo
|
||||
|
||||
@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
embedding = (average / np.linalg.norm(average)).tolist()
|
||||
if np.isnan(embedding).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
embeddings[i] = embedding
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
|
||||
|
||||
def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
|
||||
region_name = credentials.get("aws_region")
|
||||
if not region_name:
|
||||
raise InvokeBadRequestError("aws_region is required")
|
||||
client_config = Config(region_name=region_name)
|
||||
aws_access_key_id = credentials.get("aws_access_key_id")
|
||||
aws_secret_access_key = credentials.get("aws_secret_access_key")
|
||||
|
||||
def get_bedrock_client(service_name, credentials=None):
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
aws_access_key_id = credentials["aws_access_key_id"]
|
||||
aws_secret_access_key = credentials["aws_secret_access_key"]
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
# use aksk to call bedrock
|
||||
client = boto3.client(
|
||||
|
||||
@@ -62,7 +62,10 @@ class BedrockRerankModel(RerankModel):
|
||||
}
|
||||
)
|
||||
modelId = model
|
||||
region = credentials["aws_region"]
|
||||
region = credentials.get("aws_region")
|
||||
# region is a required field
|
||||
if not region:
|
||||
raise InvokeBadRequestError("aws_region is required in credentials")
|
||||
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
|
||||
rerankingConfiguration = {
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
|
||||
@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
embedding = (average / np.linalg.norm(average)).tolist()
|
||||
if np.isnan(embedding).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
embeddings[i] = embedding
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
model: InternVL2-8B
|
||||
label:
|
||||
en_US: InternVL2-8B
|
||||
model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@@ -0,0 +1,93 @@
|
||||
model: InternVL2.5-26B
|
||||
label:
|
||||
en_US: InternVL2.5-26B
|
||||
model_type: llm
|
||||
features:
|
||||
- vision
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
label:
|
||||
en_US: "Max Tokens"
|
||||
zh_Hans: "最大Token数"
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
|
||||
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
|
||||
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
en_US: "Top K"
|
||||
zh_Hans: "Top K"
|
||||
type: int
|
||||
default: 50
|
||||
min: 0
|
||||
max: 100
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
|
||||
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
|
||||
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@@ -6,3 +6,5 @@
|
||||
- deepseek-coder-33B-instruct-chat
|
||||
- deepseek-coder-33B-instruct-completions
|
||||
- codegeex4-all-9b
|
||||
- InternVL2.5-26B
|
||||
- InternVL2-8B
|
||||
|
||||
@@ -29,18 +29,26 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials, model, model_parameters)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return super()._invoke(
|
||||
GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model),
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools,
|
||||
stop,
|
||||
stream,
|
||||
user,
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials, None)
|
||||
super().validate_credentials(model, credentials)
|
||||
self._add_custom_parameters(credentials, model, None)
|
||||
super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials)
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
|
||||
def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None:
|
||||
if model is None:
|
||||
model = "Qwen2-72B-Instruct"
|
||||
|
||||
model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
|
||||
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
|
||||
credentials["endpoint_url"] = "https://ai.gitee.com/v1"
|
||||
if model.endswith("completions"):
|
||||
credentials["mode"] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
- gemini-2.0-flash-exp
|
||||
- gemini-2.0-flash-thinking-exp-1219
|
||||
- gemini-1.5-pro
|
||||
- gemini-1.5-pro-latest
|
||||
- gemini-1.5-pro-001
|
||||
@@ -11,6 +13,8 @@
|
||||
- gemini-1.5-flash-exp-0827
|
||||
- gemini-1.5-flash-8b-exp-0827
|
||||
- gemini-1.5-flash-8b-exp-0924
|
||||
- gemini-exp-1206
|
||||
- gemini-exp-1121
|
||||
- gemini-exp-1114
|
||||
- gemini-pro
|
||||
- gemini-pro-vision
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
model: gemini-2.0-flash-thinking-exp-1219
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Thinking Exp 1219
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
|
||||
@@ -7,6 +7,9 @@ features:
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
|
||||
@@ -7,6 +7,9 @@ features:
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
||||
@@ -7,6 +7,9 @@ features:
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
from google.api_core import exceptions
|
||||
from google.generativeai.client import _ClientManager
|
||||
from google.generativeai.types import ContentType, GenerateContentResponse
|
||||
from google.generativeai.types import ContentType, File, GenerateContentResponse
|
||||
from google.generativeai.types.content_types import to_part
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
@@ -35,21 +34,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
GOOGLE_AVAILABLE_MIMETYPE = [
|
||||
"application/pdf",
|
||||
"application/x-javascript",
|
||||
"text/javascript",
|
||||
"application/x-python",
|
||||
"text/x-python",
|
||||
"text/plain",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"text/md",
|
||||
"text/csv",
|
||||
"text/xml",
|
||||
"text/rtf",
|
||||
]
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
@@ -201,29 +186,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
genai.configure(api_key=credentials["google_api_key"])
|
||||
google_model = genai.GenerativeModel(model_name=model)
|
||||
|
||||
history = []
|
||||
|
||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
||||
if model == "gemini-pro-vision":
|
||||
last_msg = prompt_messages[-1]
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
# Create a new ClientManager with tenant's API key
|
||||
new_client_manager = _ClientManager()
|
||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
||||
new_custom_client = new_client_manager.make_client("generative")
|
||||
|
||||
google_model._client = new_custom_client
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
@@ -317,8 +290,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
if hasattr(response, "usage_metadata") and response.usage_metadata:
|
||||
prompt_tokens = response.usage_metadata.prompt_token_count
|
||||
completion_tokens = response.usage_metadata.candidates_token_count
|
||||
else:
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
@@ -346,7 +323,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
@@ -359,6 +336,40 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_text
|
||||
|
||||
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
|
||||
key = f"{message_content.type.value}:{hash(message_content.data)}"
|
||||
if redis_client.exists(key):
|
||||
try:
|
||||
return genai.get_file(redis_client.get(key).decode())
|
||||
except:
|
||||
pass
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
if message_content.base64_data:
|
||||
file_content = base64.b64decode(message_content.base64_data)
|
||||
temp_file.write(file_content)
|
||||
else:
|
||||
try:
|
||||
response = requests.get(message_content.url)
|
||||
response.raise_for_status()
|
||||
temp_file.write(response.content)
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}")
|
||||
temp_file.flush()
|
||||
|
||||
file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type)
|
||||
while file.state.name == "PROCESSING":
|
||||
time.sleep(5)
|
||||
file = genai.get_file(file.name)
|
||||
# google will delete your upload files in 2 days.
|
||||
redis_client.setex(key, 47 * 60 * 60, file.name)
|
||||
|
||||
try:
|
||||
os.unlink(temp_file.name)
|
||||
except PermissionError:
|
||||
# windows may raise permission error
|
||||
pass
|
||||
return file
|
||||
|
||||
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
||||
"""
|
||||
Format a single message into glm.Content for Google API
|
||||
@@ -374,28 +385,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
glm_content["parts"].append(to_part(c.data))
|
||||
elif c.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, c)
|
||||
if message_content.data.startswith("data:"):
|
||||
metadata, base64_data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
else:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
||||
glm_content["parts"].append(blob)
|
||||
elif c.type == PromptMessageContentType.DOCUMENT:
|
||||
message_content = cast(DocumentPromptMessageContent, c)
|
||||
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
|
||||
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
|
||||
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
|
||||
glm_content["parts"].append(blob)
|
||||
else:
|
||||
glm_content["parts"].append(self._upload_file_content_to_google(c))
|
||||
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
|
||||
@@ -3,8 +3,8 @@ label:
|
||||
zh_Hans: 腾讯混元
|
||||
en_US: Hunyuan
|
||||
description:
|
||||
en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro and hunyuan-lite.
|
||||
zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro 和 hunyuan-lite。
|
||||
en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall and hunyuan-lite.
|
||||
zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall 和 hunyuan-lite。
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
|
||||
@@ -4,3 +4,10 @@
|
||||
- hunyuan-pro
|
||||
- hunyuan-turbo
|
||||
- hunyuan-vision
|
||||
- hunyuan-role
|
||||
- hunyuan-large
|
||||
- hunyuan-large-role
|
||||
- hunyuan-large-longcontext
|
||||
- hunyuan-turbo-latest
|
||||
- hunyuan-turbo-vision
|
||||
- hunyuan-functioncall
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
model: hunyuan-functioncall
|
||||
label:
|
||||
zh_Hans: hunyuan-functioncall
|
||||
en_US: hunyuan-functioncall
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.004'
|
||||
output: '0.008'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -0,0 +1,38 @@
|
||||
model: hunyuan-large-longcontext
|
||||
label:
|
||||
zh_Hans: hunyuan-large-longcontext
|
||||
en_US: hunyuan-large-longcontext
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 134000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 134000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.006'
|
||||
output: '0.018'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -0,0 +1,38 @@
|
||||
model: hunyuan-large-role
|
||||
label:
|
||||
zh_Hans: hunyuan-large-role
|
||||
en_US: hunyuan-large-role
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.004'
|
||||
output: '0.008'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -0,0 +1,38 @@
|
||||
model: hunyuan-large
|
||||
label:
|
||||
zh_Hans: hunyuan-large
|
||||
en_US: hunyuan-large
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.004'
|
||||
output: '0.012'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -0,0 +1,38 @@
|
||||
model: hunyuan-role
|
||||
label:
|
||||
zh_Hans: hunyuan-role
|
||||
en_US: hunyuan-role
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.004'
|
||||
output: '0.008'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -0,0 +1,38 @@
|
||||
model: hunyuan-turbo-latest
|
||||
label:
|
||||
zh_Hans: hunyuan-turbo-latest
|
||||
en_US: hunyuan-turbo-latest
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.05'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -0,0 +1,39 @@
|
||||
model: hunyuan-turbo-vision
|
||||
label:
|
||||
zh_Hans: hunyuan-turbo-vision
|
||||
en_US: hunyuan-turbo-vision
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.08'
|
||||
output: '0.08'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,4 +1,7 @@
|
||||
- gpt-4o-audio-preview
|
||||
- o1
|
||||
- o1-2024-12-17
|
||||
- o1-mini
|
||||
- o1-mini-2024-09-12
|
||||
- gpt-4
|
||||
- gpt-4o
|
||||
- gpt-4o-2024-05-13
|
||||
@@ -7,10 +10,6 @@
|
||||
- chatgpt-4o-latest
|
||||
- gpt-4o-mini
|
||||
- gpt-4o-mini-2024-07-18
|
||||
- o1-preview
|
||||
- o1-preview-2024-09-12
|
||||
- o1-mini
|
||||
- o1-mini-2024-09-12
|
||||
- gpt-4-turbo
|
||||
- gpt-4-turbo-2024-04-09
|
||||
- gpt-4-turbo-preview
|
||||
@@ -25,4 +24,7 @@
|
||||
- gpt-3.5-turbo-1106
|
||||
- gpt-3.5-turbo-0613
|
||||
- gpt-3.5-turbo-instruct
|
||||
- gpt-4o-audio-preview
|
||||
- o1-preview
|
||||
- o1-preview-2024-09-12
|
||||
- text-davinci-003
|
||||
|
||||
@@ -22,7 +22,7 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
|
||||
@@ -22,9 +22,9 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 4096
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
|
||||
@@ -22,7 +22,7 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
|
||||
@@ -22,7 +22,7 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
|
||||
@@ -22,9 +22,9 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 4096
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
|
||||
@@ -22,7 +22,7 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
|
||||
@@ -22,7 +22,7 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
|
||||
@@ -22,9 +22,9 @@ parameter_rules:
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 16384
|
||||
min: 1
|
||||
max: 4096
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
@@ -38,7 +38,7 @@ parameter_rules:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '5.00'
|
||||
output: '15.00'
|
||||
input: '2.50'
|
||||
output: '10.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
||||
@@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif isinstance(message_content, AudioPromptMessageContent):
|
||||
data_split = message_content.data.split(";base64,")
|
||||
base64_data = data_split[1]
|
||||
sub_message_dict = {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": message_content.data,
|
||||
"data": base64_data,
|
||||
"format": message_content.format,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
model: o1-2024-12-17
|
||||
label:
|
||||
en_US: o1-2024-12-17
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 50000
|
||||
min: 1
|
||||
max: 50000
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '15.00'
|
||||
output: '60.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
36
api/core/model_runtime/model_providers/openai/llm/o1.yaml
Normal file
36
api/core/model_runtime/model_providers/openai/llm/o1.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
model: o1
|
||||
label:
|
||||
zh_Hans: o1
|
||||
en_US: o1
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 50000
|
||||
min: 1
|
||||
max: 50000
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '15.00'
|
||||
output: '60.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
embedding = (average / np.linalg.norm(average)).tolist()
|
||||
if np.isnan(embedding).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
embeddings[i] = embedding
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
@@ -119,7 +119,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
embeddings.append(result[0].get("embedding"))
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
elif "texts" == text_input_key:
|
||||
elif text_input_key == "texts":
|
||||
result = client.run(
|
||||
replicate_model_version,
|
||||
input={
|
||||
|
||||
@@ -18,7 +18,7 @@ class SiliconflowProvider(ModelProvider):
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials)
|
||||
model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
|
||||
@@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.VIDEO:
|
||||
message_content = cast(VideoPromptMessageContent, message_content)
|
||||
video_url = message_content.data
|
||||
if message_content.data.startswith("data:"):
|
||||
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
|
||||
video_url = message_content.url
|
||||
if not video_url:
|
||||
raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
|
||||
|
||||
sub_message_dict = {"video": video_url}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
@@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
embedding = (average / np.linalg.norm(average)).tolist()
|
||||
if np.isnan(embedding).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
embeddings[i] = embedding
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
|
||||
@@ -104,17 +103,16 @@ class ArkClientV3:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
content.append(
|
||||
ChatCompletionContentPartTextParam(
|
||||
text=message_content.text,
|
||||
text=message_content.data,
|
||||
type="text",
|
||||
)
|
||||
)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data)
|
||||
content.append(
|
||||
ChatCompletionContentPartImageParam(
|
||||
image_url=ImageURL(
|
||||
url=image_data,
|
||||
url=message_content.data,
|
||||
detail=message_content.detail.value,
|
||||
),
|
||||
type="image_url",
|
||||
|
||||
@@ -132,6 +132,14 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
for key, value in message.items():
|
||||
# Ignore tokens for image type
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(key))
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
||||
|
||||
|
||||
@@ -16,6 +16,14 @@ class ModelConfig(BaseModel):
|
||||
|
||||
|
||||
configs: dict[str, ModelConfig] = {
|
||||
"Doubao-vision-pro-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.VISION],
|
||||
),
|
||||
"Doubao-vision-lite-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.VISION],
|
||||
),
|
||||
"Doubao-pro-4k": ModelConfig(
|
||||
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
@@ -32,6 +40,10 @@ configs: dict[str, ModelConfig] = {
|
||||
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
),
|
||||
"Doubao-pro-256k": ModelConfig(
|
||||
properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[],
|
||||
),
|
||||
"Doubao-pro-128k": ModelConfig(
|
||||
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.TOOL_CALL],
|
||||
|
||||
@@ -12,6 +12,7 @@ class ModelConfig(BaseModel):
|
||||
|
||||
ModelConfigs = {
|
||||
"Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
|
||||
"Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
|
||||
if not model_configs:
|
||||
return ModelConfig(
|
||||
properties=ModelProperties(
|
||||
context_size=int(credentials.get("context_size", 0)),
|
||||
context_size=int(credentials.get("context_size", 4096)),
|
||||
max_chunks=int(credentials.get("max_chunks", 1)),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -118,6 +118,18 @@ model_credential_schema:
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: Doubao-vision-pro-32k
|
||||
value: Doubao-vision-pro-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-vision-lite-32k
|
||||
value: Doubao-vision-lite-32k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-pro-4k
|
||||
value: Doubao-pro-4k
|
||||
@@ -154,6 +166,12 @@ model_credential_schema:
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-pro-256k
|
||||
value: Doubao-pro-256k
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Llama3-8B
|
||||
value: Llama3-8B
|
||||
@@ -208,6 +226,12 @@ model_credential_schema:
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: Doubao-embedding-large
|
||||
value: Doubao-embedding-large
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
- label:
|
||||
en_US: Custom
|
||||
zh_Hans: 自定义
|
||||
|
||||
@@ -49,10 +49,10 @@ class LindormVectorStoreConfig(BaseModel):
|
||||
|
||||
|
||||
class LindormVectorStore(BaseVector):
|
||||
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
|
||||
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs):
|
||||
self._routing = None
|
||||
self._routing_field = None
|
||||
if config.using_ugc:
|
||||
if using_ugc:
|
||||
routing_value: str = kwargs.get("routing_value")
|
||||
if routing_value is None:
|
||||
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
|
||||
@@ -64,7 +64,7 @@ class LindormVectorStore(BaseVector):
|
||||
super().__init__(collection_name.lower())
|
||||
self._client_config = config
|
||||
self._client = OpenSearch(**config.to_opensearch_params())
|
||||
self._using_ugc = config.using_ugc
|
||||
self._using_ugc = using_ugc
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_type(self) -> str:
|
||||
@@ -467,12 +467,16 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
using_ugc = dify_config.USING_UGC_INDEX
|
||||
routing_value = None
|
||||
if dataset.index_struct:
|
||||
if using_ugc:
|
||||
# if an existed record's index_struct_dict doesn't contain using_ugc field,
|
||||
# it actually stores in the normal index format
|
||||
stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
|
||||
using_ugc = stored_in_ugc
|
||||
if stored_in_ugc:
|
||||
dimension = dataset.index_struct_dict["dimension"]
|
||||
index_type = dataset.index_struct_dict["index_type"]
|
||||
distance_type = dataset.index_struct_dict["distance_type"]
|
||||
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
|
||||
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
|
||||
else:
|
||||
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
else:
|
||||
@@ -487,6 +491,7 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
"index_type": index_type,
|
||||
"dimension": dimension,
|
||||
"distance_type": distance_type,
|
||||
"using_ugc": using_ugc,
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
if using_ugc:
|
||||
@@ -494,4 +499,4 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
routing_value = class_prefix
|
||||
else:
|
||||
index_name = class_prefix
|
||||
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
|
||||
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc)
|
||||
|
||||
@@ -65,6 +65,11 @@ class CacheEmbedding(Embeddings):
|
||||
for vector in embedding_result.embeddings:
|
||||
try:
|
||||
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
||||
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
|
||||
if np.isnan(normalized_embedding).any():
|
||||
# for issue #11827 float values are not json compliant
|
||||
logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
|
||||
continue
|
||||
embedding_queue_embeddings.append(normalized_embedding)
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
@@ -111,6 +116,8 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
|
||||
if np.isnan(embedding_results).any():
|
||||
raise ValueError("Normalized embedding is nan please try again")
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
|
||||
|
||||
@@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
ws = websocket.WebSocket()
|
||||
base_url = URL(credentials.get("base_url"))
|
||||
ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
|
||||
ws_protocol = "ws"
|
||||
if base_url.scheme == "https":
|
||||
ws_protocol = "wss"
|
||||
ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123"
|
||||
|
||||
try:
|
||||
ws.connect(ws_address)
|
||||
|
||||
@@ -40,7 +40,10 @@ class ComfyUiClient:
|
||||
def open_websocket_connection(self) -> tuple[WebSocket, str]:
|
||||
client_id = str(uuid.uuid4())
|
||||
ws = WebSocket()
|
||||
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
|
||||
ws_protocol = "ws"
|
||||
if self.base_url.scheme == "https":
|
||||
ws_protocol = "wss"
|
||||
ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}"
|
||||
ws.connect(ws_address)
|
||||
return ws, client_id
|
||||
|
||||
|
||||
@@ -43,6 +43,13 @@ class JinaReaderTool(BuiltinTool):
|
||||
if wait_for_selector is not None and wait_for_selector != "":
|
||||
headers["X-Wait-For-Selector"] = wait_for_selector
|
||||
|
||||
remove_selector = tool_parameters.get("remove_selector")
|
||||
if remove_selector is not None and remove_selector != "":
|
||||
headers["X-Remove-Selector"] = remove_selector
|
||||
|
||||
if tool_parameters.get("retain_images", False):
|
||||
headers["X-Retain-Images"] = "true"
|
||||
|
||||
if tool_parameters.get("image_caption", False):
|
||||
headers["X-With-Generated-Alt"] = "true"
|
||||
|
||||
@@ -59,6 +66,12 @@ class JinaReaderTool(BuiltinTool):
|
||||
if tool_parameters.get("no_cache", False):
|
||||
headers["X-No-Cache"] = "true"
|
||||
|
||||
if tool_parameters.get("with_iframe", False):
|
||||
headers["X-With-Iframe"] = "true"
|
||||
|
||||
if tool_parameters.get("with_shadow_dom", False):
|
||||
headers["X-With-Shadow-Dom"] = "true"
|
||||
|
||||
max_retries = tool_parameters.get("max_retries", 3)
|
||||
response = ssrf_proxy.get(
|
||||
str(URL(self._jina_reader_endpoint + url)),
|
||||
|
||||
@@ -67,6 +67,33 @@ parameters:
|
||||
pt_BR: css selector para aguardar elementos específicos
|
||||
llm_description: css selector of the target element to wait for
|
||||
form: form
|
||||
- name: remove_selector
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Excluded Selector
|
||||
zh_Hans: 排除选择器
|
||||
pt_BR: Seletor Excluído
|
||||
human_description:
|
||||
en_US: css selector for remove for specific elements
|
||||
zh_Hans: css 选择器用于排除特定元素
|
||||
pt_BR: seletor CSS para remover elementos específicos
|
||||
llm_description: css selector of the target element to remove for
|
||||
form: form
|
||||
- name: retain_images
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Remove All Images
|
||||
zh_Hans: 删除所有图片
|
||||
pt_BR: Remover todas as imagens
|
||||
human_description:
|
||||
en_US: Removes all images from the response.
|
||||
zh_Hans: 从响应中删除所有图片。
|
||||
pt_BR: Remove todas as imagens da resposta.
|
||||
llm_description: Remove all images
|
||||
form: form
|
||||
- name: image_caption
|
||||
type: boolean
|
||||
required: false
|
||||
@@ -136,6 +163,34 @@ parameters:
|
||||
pt_BR: Ignorar o cache
|
||||
llm_description: bypass the cache
|
||||
form: form
|
||||
- name: with_iframe
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Enable iframe extraction
|
||||
zh_Hans: 启用 iframe 提取
|
||||
pt_BR: Habilitar extração de iframe
|
||||
human_description:
|
||||
en_US: Extract and process content of all embedded iframes in the DOM tree.
|
||||
zh_Hans: 提取并处理 DOM 树中所有嵌入 iframe 的内容。
|
||||
pt_BR: Extrair e processar o conteúdo de todos os iframes incorporados na árvore DOM.
|
||||
llm_description: Extract content from embedded iframes
|
||||
form: form
|
||||
- name: with_shadow_dom
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
label:
|
||||
en_US: Enable Shadow DOM extraction
|
||||
zh_Hans: 启用 Shadow DOM 提取
|
||||
pt_BR: Habilitar extração de Shadow DOM
|
||||
human_description:
|
||||
en_US: Traverse all Shadow DOM roots in the document and extract content.
|
||||
zh_Hans: 遍历文档中所有 Shadow DOM 根并提取内容。
|
||||
pt_BR: Percorra todas as raízes do Shadow DOM no documento e extraia o conteúdo.
|
||||
llm_description: Extract content from Shadow DOM roots
|
||||
form: form
|
||||
- name: summary
|
||||
type: boolean
|
||||
required: false
|
||||
|
||||
@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
error_type: Optional[str] = None # error type if status is failed
|
||||
|
||||
# single step node run retry
|
||||
retry_index: int = 0
|
||||
|
||||
@@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
start_at: datetime = Field(..., description="retry start time")
|
||||
start_index: int = Field(..., description="retry start index")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
@@ -581,7 +583,7 @@ class GraphEngine:
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
node_instance: BaseNode,
|
||||
node_instance: BaseNode[BaseNodeData],
|
||||
route_node_state: RouteNodeState,
|
||||
parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
@@ -607,36 +609,121 @@ class GraphEngine:
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||
retries = 0
|
||||
shoudl_continue_retry = True
|
||||
while shoudl_continue_retry and retries <= max_retries:
|
||||
try:
|
||||
# run node
|
||||
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if (
|
||||
retries == max_retries
|
||||
and node_instance.node_type == NodeType.HTTP_REQUEST
|
||||
and run_result.outputs
|
||||
and not node_instance.should_continue_on_error
|
||||
):
|
||||
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if node_instance.should_retry and retries < max_retries:
|
||||
retries += 1
|
||||
self.graph_runtime_state.node_run_steps += 1
|
||||
route_node_state.node_run_result = run_result
|
||||
yield NodeRunRetryEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
error=run_result.error,
|
||||
retry_index=retries,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
start_at=retry_start_at,
|
||||
start_index=self.graph_runtime_state.node_run_steps,
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
continue
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
shoudl_continue_retry = False
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
shoudl_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
@@ -645,21 +732,23 @@ class GraphEngine:
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
|
||||
parallel_start_node_id
|
||||
)
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
@@ -670,108 +759,59 @@ class GraphEngine:
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
shoudl_continue_retry = False
|
||||
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
|
||||
@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
version: str = "1"
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self):
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
@@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
result = self._run()
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {self.node_id} failed to run")
|
||||
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
@@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
bool: if should continue on error
|
||||
"""
|
||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||
|
||||
@property
|
||||
def should_retry(self) -> bool:
|
||||
"""judge if should retry
|
||||
|
||||
Returns:
|
||||
bool: if should retry
|
||||
"""
|
||||
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
|
||||
|
||||
@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
|
||||
|
||||
|
||||
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from .event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
RunCompletedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
@@ -6,5 +12,6 @@ __all__ = [
|
||||
"NodeEvent",
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"RunStreamChunkEvent",
|
||||
]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user