mirror of
https://github.com/langgenius/dify.git
synced 2026-04-11 12:14:34 +08:00
Compare commits
114 Commits
build/trig
...
mysql-adap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f586c44d8 | ||
|
|
5322f3bbd4 | ||
|
|
6433ac8209 | ||
|
|
84935b9169 | ||
|
|
eceaea68b1 | ||
|
|
26cfccb84b | ||
|
|
3042b69e77 | ||
|
|
49d5637b3c | ||
|
|
20403c69b2 | ||
|
|
ffc04f2a9b | ||
|
|
d1580791e4 | ||
|
|
c74eb4fcf3 | ||
|
|
a798534337 | ||
|
|
470883858e | ||
|
|
4f4911686d | ||
|
|
6d479dcdbb | ||
|
|
24348c40a6 | ||
|
|
a39b50adbb | ||
|
|
81832c14ee | ||
|
|
b86022c64a | ||
|
|
45e816a9f6 | ||
|
|
667b1c37a3 | ||
|
|
b75d533f9b | ||
|
|
aece55d82f | ||
|
|
c432b398f4 | ||
|
|
9cb2645793 | ||
|
|
6ac61bd585 | ||
|
|
b02165ffe6 | ||
|
|
6c576e2c66 | ||
|
|
b0e7e7752f | ||
|
|
2799b79e8c | ||
|
|
805a1479f9 | ||
|
|
fe6538b08d | ||
|
|
1bbb9d6644 | ||
|
|
5c06e285ec | ||
|
|
19c92fd670 | ||
|
|
6026bd873b | ||
|
|
1369119a0c | ||
|
|
b76e17b25d | ||
|
|
ca7794305b | ||
|
|
fd255e81e1 | ||
|
|
09d31d1263 | ||
|
|
47dc26f011 | ||
|
|
123bb3ec08 | ||
|
|
90f77282e3 | ||
|
|
5208867ccc | ||
|
|
edc7ccc795 | ||
|
|
c9798f6425 | ||
|
|
20ecf7f1d0 | ||
|
|
9dcb780fcb | ||
|
|
1cb7b09933 | ||
|
|
2c62a77cf4 | ||
|
|
b9bc48d8dd | ||
|
|
ed234e311b | ||
|
|
9843fec393 | ||
|
|
aa4cabdeb5 | ||
|
|
eea713b668 | ||
|
|
fc62538a94 | ||
|
|
7994144df7 | ||
|
|
e153c483b6 | ||
|
|
422bb4d4bb | ||
|
|
87a80d7613 | ||
|
|
e91105ca87 | ||
|
|
37903722fe | ||
|
|
f4c82d0010 | ||
|
|
fe50093c18 | ||
|
|
4317af1e90 | ||
|
|
61a0fcc2ea | ||
|
|
f627348b11 | ||
|
|
87fb9a6b69 | ||
|
|
97a2e2ec2e | ||
|
|
68d357d7f6 | ||
|
|
a103ad3ee7 | ||
|
|
f65d5a9761 | ||
|
|
6e0a5f5bbd | ||
|
|
22f858152f | ||
|
|
775d2e14fc | ||
|
|
744b287e67 | ||
|
|
c0fc5d98f0 | ||
|
|
08ea79d730 | ||
|
|
f31b821cc0 | ||
|
|
34be16874f | ||
|
|
e9738b891f | ||
|
|
829796514a | ||
|
|
ef1db35f80 | ||
|
|
f9c67621ca | ||
|
|
e29e8e3180 | ||
|
|
7a81e720d4 | ||
|
|
55600c0eb1 | ||
|
|
35e41d7d68 | ||
|
|
b610cf9a11 | ||
|
|
c8e9edc024 | ||
|
|
471cd760d7 | ||
|
|
7f48c57edf | ||
|
|
6569801162 | ||
|
|
9dd83f50a7 | ||
|
|
59c56b1b0d | ||
|
|
94cd2de940 | ||
|
|
3c23375607 | ||
|
|
56047f638f | ||
|
|
9c01d3e775 | ||
|
|
c85c87f3da | ||
|
|
eaa02e3d55 | ||
|
|
0219222a60 | ||
|
|
dba659b220 | ||
|
|
ee6458768e | ||
|
|
ed3d02dc6d | ||
|
|
95471b1188 | ||
|
|
6190cfbfd8 | ||
|
|
11f2f95103 | ||
|
|
2abbc14703 | ||
|
|
b2b2816ade | ||
|
|
4461df1bd9 | ||
|
|
f7f6b4a8b0 |
@@ -1,17 +1,15 @@
|
||||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
npm add -g pnpm@10.15.0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
2
.github/workflows/api-tests.yml
vendored
2
.github/workflows/api-tests.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
db_postgres
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
||||
61
.github/workflows/db-migration-test.yml
vendored
61
.github/workflows/db-migration-test.yml
vendored
@@ -8,7 +8,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
db-migration-test:
|
||||
db-migration-test-postgres:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
db_postgres
|
||||
redis
|
||||
|
||||
- name: Prepare configs
|
||||
@@ -57,3 +57,60 @@ jobs:
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env for MySQL
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
|
||||
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_mysql
|
||||
redis
|
||||
|
||||
- name: Prepare configs for MySQL
|
||||
run: |
|
||||
cd api
|
||||
cp .env.example .env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -186,6 +186,8 @@ docker/volumes/couchbase/*
|
||||
docker/volumes/oceanbase/*
|
||||
docker/volumes/plugin_daemon/*
|
||||
docker/volumes/matrixone/*
|
||||
docker/volumes/mysql/*
|
||||
docker/volumes/seekdb/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
|
||||
@@ -117,7 +117,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations<br/>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
|
||||
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
|
||||
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
|
||||
@@ -72,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
# PostgreSQL database configuration
|
||||
|
||||
# Database configuration
|
||||
DB_TYPE=postgresql
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
|
||||
@@ -164,7 +167,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@@ -175,6 +178,17 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
OCEANBASE_FULLTEXT_PARSER=ik
|
||||
SEEKDB_MEMORY_LIMIT=2G
|
||||
|
||||
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_API_KEY=difyai123456
|
||||
@@ -340,15 +354,6 @@ LINDORM_PASSWORD=admin
|
||||
LINDORM_USING_UGC=True
|
||||
LINDORM_QUERY_TIMEOUT=1
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
|
||||
# AlibabaCloud MySQL Vector configuration
|
||||
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||
ALIBABACLOUD_MYSQL_PORT=3306
|
||||
@@ -374,6 +379,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Comma-separated list of file extensions blocked from upload for security reasons.
|
||||
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
|
||||
# Empty by default to allow all file types.
|
||||
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
|
||||
UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
@@ -521,7 +532,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
# Workflow log cleanup configuration
|
||||
# Enable automatic cleanup of workflow run logs to manage database size
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=true
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
# Number of days to retain workflow run logs (default: 30 days)
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
# Batch size for workflow log cleanup operations (default: 100)
|
||||
@@ -620,3 +631,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
@@ -15,7 +15,11 @@ FROM base AS packages
|
||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml uv.lock ./
|
||||
@@ -49,7 +53,9 @@ RUN \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
curl nodejs \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install fonts to support the use of tools like pypdfium2
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
```bash
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
# change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
|
||||
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
cd ../api
|
||||
```
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
|
||||
|
||||
def is_db_command():
|
||||
def is_db_command() -> bool:
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1471,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
def transform_datasource_credentials():
|
||||
@click.option(
|
||||
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||
)
|
||||
def transform_datasource_credentials(environment: str):
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
@@ -1482,9 +1485,14 @@ def transform_datasource_credentials():
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
jina_plugin_unique_identifier = None
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
@@ -1650,7 +1658,7 @@ def transform_datasource_credentials():
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
provider="jinareader",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
|
||||
@@ -360,6 +360,31 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
|
||||
"Empty by default to allow all file types."
|
||||
),
|
||||
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
|
||||
"""
|
||||
Parse and return the blacklist as a set of lowercase extensions.
|
||||
Returns an empty set if no blacklist is configured.
|
||||
"""
|
||||
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
|
||||
return set()
|
||||
return {
|
||||
ext.strip().lower().strip(".")
|
||||
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
|
||||
if ext.strip()
|
||||
}
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
@@ -949,6 +974,11 @@ class DataSetConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
|
||||
description="Maximum number of segments for dataset segments API (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
@@ -1160,7 +1190,7 @@ class AccountConfig(BaseSettings):
|
||||
|
||||
|
||||
class WorkflowLogConfig(BaseSettings):
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
|
||||
default=100, description="Batch size for workflow run log cleanup operations"
|
||||
@@ -1179,6 +1209,13 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -1205,6 +1242,7 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
|
||||
@@ -105,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
DB_HOST: str = Field(
|
||||
description="Hostname or IP address of the database server.",
|
||||
default="localhost",
|
||||
@@ -140,10 +146,10 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description="Database URI scheme for SQLAlchemy connection.",
|
||||
default="postgresql",
|
||||
)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
@@ -204,15 +210,15 @@ class DatabaseConfig(BaseSettings):
|
||||
# Parse DB_EXTRAS for 'options'
|
||||
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
|
||||
options = db_extras_dict.get("options", "")
|
||||
# Always include timezone
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
# Merge user options and timezone
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
|
||||
connect_args = {"options": merged_options}
|
||||
connect_args = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
return {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
|
||||
@@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
|
||||
7343
api/constants/pipeline_templates.json
Normal file
7343
api/constants/pipeline_templates.json
Normal file
File diff suppressed because one or more lines are too long
@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
||||
code = 415
|
||||
|
||||
|
||||
class BlockedFileExtensionError(BaseHTTPException):
|
||||
error_code = "file_extension_blocked"
|
||||
description = "The file extension is blocked for security reasons."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
|
||||
@@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
@@ -25,13 +27,6 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args")
|
||||
.add_argument("model_mode", type=str, required=True, location="args")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||
.add_argument("model_name", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
@@ -8,17 +8,19 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@@ -27,12 +29,6 @@ class AgentLogApi(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
@@ -16,6 +16,7 @@ from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@@ -175,8 +176,10 @@ class AnnotationApi(Resource):
|
||||
api.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
@@ -193,11 +196,14 @@ class AnnotationApi(Resource):
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@@ -245,6 +251,13 @@ class AnnotationExportApi(Resource):
|
||||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.doc("update_delete_annotation")
|
||||
@@ -253,6 +266,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -262,11 +276,6 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
|
||||
@@ -15,11 +15,12 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App
|
||||
from models import App, Workflow
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@@ -106,6 +107,35 @@ class AppListApi(Resource):
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
@api.doc("create_app")
|
||||
@@ -353,12 +383,15 @@ class AppExportApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -367,7 +400,6 @@ class AppNameApi(Resource):
|
||||
@marshal_with(app_detail_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -18,9 +19,23 @@ from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -30,18 +45,6 @@ class AppImportApi(Resource):
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import func, or_
|
||||
@@ -19,7 +17,7 @@ from fields.conversation_fields import (
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
@@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
account = current_user
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
@@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
|
||||
|
||||
account = current_user
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
|
||||
@@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
@@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
@@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class MessageAnnotationApi(Resource):
|
||||
@api.doc("create_message_annotation")
|
||||
@api.doc(description="Create message annotation")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@marshal_with(annotation_fields)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@api.doc("get_annotation_count")
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
@@ -11,9 +9,10 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString, convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode, Message
|
||||
from models import AppMode
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
@@ -45,8 +44,9 @@ class DailyMessageStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(*) AS message_count
|
||||
FROM
|
||||
messages
|
||||
@@ -56,26 +56,16 @@ WHERE
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -91,16 +81,19 @@ WHERE
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
@@ -113,48 +106,40 @@ class DailyConversationStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
stmt = (
|
||||
sa.select(
|
||||
sa.func.date(
|
||||
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
).label("date"),
|
||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||
)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(stmt, {"tz": account.timezone})
|
||||
for row in rs:
|
||||
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
@@ -164,11 +149,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
@@ -181,15 +162,11 @@ class DailyTerminalsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||
FROM
|
||||
messages
|
||||
@@ -198,26 +175,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -238,11 +206,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
@@ -255,15 +219,11 @@ class DailyTokenCostStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||
SUM(total_price) AS total_price
|
||||
FROM
|
||||
@@ -273,26 +233,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -315,11 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
@@ -332,15 +279,11 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM
|
||||
(
|
||||
@@ -357,26 +300,17 @@ FROM
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -408,11 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
@@ -425,15 +355,11 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
COUNT(mf.id) AS feedback_count
|
||||
FROM
|
||||
@@ -446,26 +372,17 @@ WHERE
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -491,11 +408,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
@@ -508,15 +421,11 @@ class AverageResponseTimeStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
AVG(provider_response_latency) AS latency
|
||||
FROM
|
||||
messages
|
||||
@@ -525,26 +434,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -565,11 +465,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
@@ -581,16 +477,11 @@ class TokensPerSecondStatistic(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
@@ -602,26 +493,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
@@ -112,7 +113,18 @@ class DraftWorkflowApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Draft workflow synced successfully", workflow_fields)
|
||||
@api.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
api.model(
|
||||
"SyncDraftWorkflowResponse",
|
||||
{
|
||||
"result": fields.String,
|
||||
"hash": fields.String,
|
||||
"updated_at": fields.String,
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid workflow configuration")
|
||||
@api.response(403, "Permission denied")
|
||||
@edit_permission_required
|
||||
@@ -574,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@api.doc("get_published_workflow")
|
||||
@@ -598,6 +617,7 @@ class PublishedWorkflowApi(Resource):
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@api.expect(parser_publish)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -608,12 +628,8 @@ class PublishedWorkflowApi(Resource):
|
||||
Publish workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
@@ -668,6 +684,9 @@ class DefaultBlockConfigsApi(Resource):
|
||||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@api.doc("get_default_block_config")
|
||||
@@ -675,6 +694,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@api.response(200, "Default block configuration retrieved successfully")
|
||||
@api.response(404, "Block type not found")
|
||||
@api.expect(parser_block)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -684,8 +704,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
@@ -701,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
|
||||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_convert = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@api.expect(parser_convert)
|
||||
@api.doc("convert_to_workflow")
|
||||
@api.doc(description="Convert application to workflow mode")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@@ -723,14 +752,7 @@ class ConvertToWorkflowApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if request.data:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_convert.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
|
||||
@@ -744,8 +766,18 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_workflows = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.expect(parser_workflows)
|
||||
@api.doc("get_all_published_workflows")
|
||||
@api.doc(description="Get all published workflows for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@@ -762,16 +794,9 @@ class PublishedAllWorkflowApi(Resource):
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
args = parser_workflows.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
@@ -979,11 +1004,13 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||
event = poller.poll()
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
workflow_args = dict(event.workflow_args)
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
return helper.compact_generate_response(
|
||||
AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=event.workflow_args,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
root_node_id=node_id,
|
||||
@@ -992,7 +1019,7 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except PluginInvokeError as e:
|
||||
raise ValueError(e.to_user_friendly_error())
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
@@ -1050,7 +1077,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
||||
)
|
||||
event = poller.poll()
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 500
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
@@ -1074,7 +1101,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
||||
logger.exception("Error running draft workflow trigger node")
|
||||
return jsonable_encoder(
|
||||
{"status": "error", "error": "An unexpected error occurred while running the node."}
|
||||
), 500
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
|
||||
@@ -1126,7 +1153,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
node_ids=node_ids,
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
raise ValueError(e.to_user_friendly_error())
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
@@ -1134,10 +1161,12 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
|
||||
try:
|
||||
workflow_args = dict(trigger_debug_event.workflow_args)
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=trigger_debug_event.workflow_args,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
root_node_id=trigger_debug_event.node_id,
|
||||
@@ -1151,4 +1180,4 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
), 400
|
||||
|
||||
@@ -30,23 +30,25 @@ def _parse_workflow_run_list_args():
|
||||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -58,28 +60,30 @@ def _parse_workflow_run_count_args():
|
||||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@@ -9,6 +6,7 @@ from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
@@ -43,23 +41,11 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@@ -100,23 +86,11 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@@ -157,23 +131,11 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@@ -214,23 +176,11 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_date = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_date = end_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
||||
@@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@@ -16,7 +17,13 @@ class Subscription(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
|
||||
@@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -121,8 +121,16 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@api.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -130,14 +138,7 @@ class DatasourceAuth(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@@ -168,8 +169,14 @@ class DatasourceAuth(Resource):
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@api.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -181,10 +188,7 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource_delete.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@@ -195,8 +199,17 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@api.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -205,13 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource_update.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
@@ -251,8 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@api.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -260,12 +275,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource_custom.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
@@ -291,8 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -300,8 +314,7 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_default.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
@@ -312,8 +325,16 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@api.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -321,12 +342,7 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_update_name.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
@@ -148,8 +148,12 @@ class DraftRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@api.expect(parser_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -162,8 +166,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_run.parse_args()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_iteration(
|
||||
@@ -184,6 +187,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@api.expect(parser_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -197,8 +201,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_run.parse_args()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_loop(
|
||||
@@ -217,8 +220,18 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
parser_draft_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@api.expect(parser_draft_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -232,14 +245,7 @@ class DraftRagPipelineRunApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_draft_run.parse_args()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate(
|
||||
@@ -255,8 +261,21 @@ class DraftRagPipelineRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
parser_published_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@api.expect(parser_published_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -270,17 +289,7 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_published_run.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
|
||||
@@ -381,8 +390,17 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
#
|
||||
# return result
|
||||
#
|
||||
parser_rag_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@api.expect(parser_rag_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -396,13 +414,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
@@ -429,6 +441,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@api.expect(parser_rag_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -442,13 +455,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
@@ -473,8 +480,14 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_run_api = reqparse.RequestParser().add_argument(
|
||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@api.expect(parser_run_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -489,10 +502,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_run_api.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
@@ -607,8 +617,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -622,8 +636,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_default.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
@@ -639,8 +652,18 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_wf = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@api.expect(parser_wf)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -654,16 +677,9 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
args = parser_wf.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
@@ -691,8 +707,16 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_wf_id = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@api.expect(parser_wf_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -707,19 +731,13 @@ class RagPipelineByIdApi(Resource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_id.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
@@ -752,8 +770,12 @@ class RagPipelineByIdApi(Resource):
|
||||
return workflow
|
||||
|
||||
|
||||
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -763,8 +785,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
@@ -777,6 +798,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -786,8 +808,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
@@ -800,6 +821,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -809,8 +831,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
@@ -823,6 +844,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -832,8 +854,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
@@ -845,8 +866,16 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_wf_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@api.expect(parser_wf_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -856,12 +885,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_run.parse_args()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
|
||||
@@ -961,8 +985,18 @@ class RagPipelineTransformApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
parser_var = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@api.expect(parser_var)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -974,14 +1008,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||
Set datasource variables
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_var.parse_args()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from libs.login import current_user, login_required
|
||||
@@ -35,15 +35,18 @@ recommended_app_list_fields = {
|
||||
}
|
||||
|
||||
|
||||
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@api.expect(parser_apps)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_fields)
|
||||
def get(self):
|
||||
# language args
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_apps.parse_args()
|
||||
|
||||
language = args.get("language")
|
||||
if language and language in languages:
|
||||
|
||||
@@ -8,6 +8,7 @@ import services
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
@@ -83,6 +84,8 @@ class FileApi(Resource):
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
|
||||
raise BlockedFileExtensionError(blocked_extension_error.description)
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import api
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
@@ -36,12 +37,15 @@ class RemoteFileInfoApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@api.expect(parser_upload)
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
args = parser_upload.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class SetupApi(Resource):
|
||||
"email": fields.String(required=True, description="Admin email address"),
|
||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
||||
"password": fields.String(required=True, description="Admin password"),
|
||||
"language": fields.String(required=False, description="Admin language"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -16,6 +16,19 @@ def _validate_name(name):
|
||||
return name
|
||||
|
||||
|
||||
parser_tags = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@@ -30,6 +43,7 @@ class TagListApi(Resource):
|
||||
|
||||
return tags, 200
|
||||
|
||||
@api.expect(parser_tags)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -39,20 +53,7 @@ class TagListApi(Resource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_tags.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
@@ -60,8 +61,14 @@ class TagListApi(Resource):
|
||||
return response, 200
|
||||
|
||||
|
||||
parser_tag_id = reqparse.RequestParser().add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@api.expect(parser_tag_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -72,10 +79,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_tag_id.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
@@ -99,8 +103,17 @@ class TagUpdateDeleteApi(Resource):
|
||||
return 204
|
||||
|
||||
|
||||
parser_create = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@api.expect(parser_create)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -110,26 +123,23 @@ class TagBindingCreateApi(Resource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
|
||||
)
|
||||
.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_create.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_remove = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@api.expect(parser_remove)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -139,15 +149,7 @@ class TagBindingDeleteApi(Resource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_remove.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -11,16 +11,16 @@ from . import api, console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/version")
|
||||
class VersionApi(Resource):
|
||||
@api.doc("check_version_update")
|
||||
@api.doc(description="Check for application version updates")
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
)
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
@@ -37,7 +37,6 @@ class VersionApi(Resource):
|
||||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailChangeLimitError,
|
||||
@@ -43,8 +43,19 @@ from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
|
||||
def _init_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
parser.add_argument("invitation_code", type=str, location="json")
|
||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
|
||||
"timezone", type=timezone, required=True, location="json"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
class AccountInitApi(Resource):
|
||||
@api.expect(_init_parser())
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
@@ -53,14 +64,7 @@ class AccountInitApi(Resource):
|
||||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
parser.add_argument("invitation_code", type=str, location="json")
|
||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
|
||||
"timezone", type=timezone, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = _init_parser().parse_args()
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
if not args["invitation_code"]:
|
||||
@@ -106,16 +110,19 @@ class AccountProfileApi(Resource):
|
||||
return current_user
|
||||
|
||||
|
||||
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
class AccountNameApi(Resource):
|
||||
@api.expect(parser_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_name.parse_args()
|
||||
|
||||
# Validate account name length
|
||||
if len(args["name"]) < 3 or len(args["name"]) > 30:
|
||||
@@ -126,68 +133,80 @@ class AccountNameApi(Resource):
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@api.expect(parser_avatar)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_avatar.parse_args()
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_interface = reqparse.RequestParser().add_argument(
|
||||
"interface_language", type=supported_language, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
class AccountInterfaceLanguageApi(Resource):
|
||||
@api.expect(parser_interface)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"interface_language", type=supported_language, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_interface.parse_args()
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_theme = reqparse.RequestParser().add_argument(
|
||||
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
class AccountInterfaceThemeApi(Resource):
|
||||
@api.expect(parser_theme)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_theme.parse_args()
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
class AccountTimezoneApi(Resource):
|
||||
@api.expect(parser_timezone)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_timezone.parse_args()
|
||||
|
||||
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
||||
if args["timezone"] not in pytz.all_timezones:
|
||||
@@ -198,21 +217,24 @@ class AccountTimezoneApi(Resource):
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_pw = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("password", type=str, required=False, location="json")
|
||||
.add_argument("new_password", type=str, required=True, location="json")
|
||||
.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
class AccountPasswordApi(Resource):
|
||||
@api.expect(parser_pw)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("password", type=str, required=False, location="json")
|
||||
.add_argument("new_password", type=str, required=True, location="json")
|
||||
.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_pw.parse_args()
|
||||
|
||||
if args["new_password"] != args["repeat_new_password"]:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
@@ -294,20 +316,23 @@ class AccountDeleteVerifyApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_delete = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/delete")
|
||||
class AccountDeleteApi(Resource):
|
||||
@api.expect(parser_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_delete.parse_args()
|
||||
|
||||
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
|
||||
raise InvalidAccountDeletionCodeError()
|
||||
@@ -317,16 +342,19 @@ class AccountDeleteApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_feedback = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("feedback", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/feedback")
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@api.expect(parser_feedback)
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("feedback", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_feedback.parse_args()
|
||||
|
||||
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
|
||||
|
||||
@@ -351,6 +379,14 @@ class EducationVerifyApi(Resource):
|
||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
|
||||
|
||||
parser_edu = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("institution", type=str, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/education")
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
@@ -360,6 +396,7 @@ class EducationApi(Resource):
|
||||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@api.expect(parser_edu)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -368,13 +405,7 @@ class EducationApi(Resource):
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("institution", type=str, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_edu.parse_args()
|
||||
|
||||
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
|
||||
|
||||
@@ -394,6 +425,14 @@ class EducationApi(Resource):
|
||||
return res
|
||||
|
||||
|
||||
parser_autocomplete = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keywords", type=str, required=True, location="args")
|
||||
.add_argument("page", type=int, required=False, location="args", default=0)
|
||||
.add_argument("limit", type=int, required=False, location="args", default=20)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/education/autocomplete")
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
data_fields = {
|
||||
@@ -402,6 +441,7 @@ class EducationAutoCompleteApi(Resource):
|
||||
"has_next": fields.Boolean,
|
||||
}
|
||||
|
||||
@api.expect(parser_autocomplete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -409,33 +449,30 @@ class EducationAutoCompleteApi(Resource):
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keywords", type=str, required=True, location="args")
|
||||
.add_argument("page", type=int, required=False, location="args", default=0)
|
||||
.add_argument("limit", type=int, required=False, location="args", default=20)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_autocomplete.parse_args()
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
|
||||
|
||||
|
||||
parser_change_email = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
.add_argument("phase", type=str, required=False, location="json")
|
||||
.add_argument("token", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email")
|
||||
class ChangeEmailSendEmailApi(Resource):
|
||||
@api.expect(parser_change_email)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
.add_argument("phase", type=str, required=False, location="json")
|
||||
.add_argument("token", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_change_email.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
@@ -470,20 +507,23 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_validity = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/validity")
|
||||
class ChangeEmailCheckApi(Resource):
|
||||
@api.expect(parser_validity)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_validity.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
@@ -514,20 +554,23 @@ class ChangeEmailCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
parser_reset = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("new_email", type=email, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
class ChangeEmailResetApi(Resource):
|
||||
@api.expect(parser_reset)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("new_email", type=email, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_reset.parse_args()
|
||||
|
||||
if AccountService.is_account_in_freeze(args["new_email"]):
|
||||
raise AccountInFreezeError()
|
||||
@@ -555,12 +598,15 @@ class ChangeEmailResetApi(Resource):
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
class CheckEmailUnique(Resource):
|
||||
@api.expect(parser_check)
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_check.parse_args()
|
||||
if AccountService.is_account_in_freeze(args["email"]):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args["email"]):
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
@@ -48,22 +48,25 @@ class MemberListApi(Resource):
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
parser_invite = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("emails", type=list, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
class MemberInviteEmailApi(Resource):
|
||||
"""Invite a new member by email."""
|
||||
|
||||
@api.expect(parser_invite)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("emails", type=list, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_invite.parse_args()
|
||||
|
||||
invitee_emails = args["emails"]
|
||||
invitee_role = args["role"]
|
||||
@@ -143,16 +146,19 @@ class MemberCancelInviteApi(Resource):
|
||||
}, 200
|
||||
|
||||
|
||||
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
"""Update member role."""
|
||||
|
||||
@api.expect(parser_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id):
|
||||
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_update.parse_args()
|
||||
new_role = args["role"]
|
||||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
@@ -191,17 +197,20 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
class SendOwnerTransferEmailApi(Resource):
|
||||
"""Send owner transfer email."""
|
||||
|
||||
@api.expect(parser_send)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_send.parse_args()
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
@@ -229,19 +238,22 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_owner = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/owner-transfer-check")
|
||||
class OwnerTransferCheckApi(Resource):
|
||||
@api.expect(parser_owner)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_owner.parse_args()
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
@@ -276,17 +288,20 @@ class OwnerTransferCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
parser_owner_transfer = reqparse.RequestParser().add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
|
||||
class OwnerTransfer(Resource):
|
||||
@api.expect(parser_owner_transfer)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_owner_transfer.parse_args()
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import send_file
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@@ -14,9 +14,19 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
parser_model = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
class ModelProviderListApi(Resource):
|
||||
@api.expect(parser_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -24,15 +34,7 @@ class ModelProviderListApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_model.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
|
||||
@@ -40,8 +42,30 @@ class ModelProviderListApi(Resource):
|
||||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
parser_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||
)
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_delete_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
@api.expect(parser_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -49,10 +73,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
@@ -61,6 +82,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@api.expect(parser_post_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -69,12 +91,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_post_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -90,6 +107,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@api.expect(parser_put_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -98,13 +116,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_put_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -121,6 +133,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_delete_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -128,10 +141,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser_delete_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
@@ -141,8 +152,14 @@ class ModelProviderCredentialApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -150,10 +167,7 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_switch.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
@@ -164,17 +178,20 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = reqparse.RequestParser().add_argument(
|
||||
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@api.expect(parser_validate)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_validate.parse_args()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@@ -218,8 +235,19 @@ class ModelProviderIconApi(Resource):
|
||||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
parser_preferred = reqparse.RequestParser().add_argument(
|
||||
"preferred_provider_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=["system", "custom"],
|
||||
location="json",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@api.expect(parser_preferred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -230,15 +258,7 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"preferred_provider_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=["system", "custom"],
|
||||
location="json",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_preferred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@@ -16,23 +16,29 @@ from services.model_provider_service import ModelProviderService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
parser_get_default = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
parser_post_default = reqparse.RequestParser().add_argument(
|
||||
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
class DefaultModelApi(Resource):
|
||||
@api.expect(parser_get_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_get_default.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
@@ -41,6 +47,7 @@ class DefaultModelApi(Resource):
|
||||
|
||||
return jsonable_encoder({"data": default_model_entity})
|
||||
|
||||
@api.expect(parser_post_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -50,10 +57,7 @@ class DefaultModelApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_post_default.parse_args()
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args["model_settings"]
|
||||
for model_setting in model_settings:
|
||||
@@ -84,6 +88,35 @@ class DefaultModelApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_post_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
|
||||
class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
@@ -97,6 +130,7 @@ class ModelProviderModelApi(Resource):
|
||||
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
@api.expect(parser_post_models)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -106,23 +140,7 @@ class ModelProviderModelApi(Resource):
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_post_models.parse_args()
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
@@ -160,6 +178,7 @@ class ModelProviderModelApi(Resource):
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@api.expect(parser_delete_models)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -169,19 +188,7 @@ class ModelProviderModelApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_delete_models.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model(
|
||||
@@ -191,29 +198,76 @@ class ModelProviderModelApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_get_credentials = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
@api.expect(parser_get_credentials)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_get_credentials.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
@@ -257,6 +311,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
}
|
||||
)
|
||||
|
||||
@api.expect(parser_post_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -266,21 +321,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_post_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -304,6 +345,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@api.expect(parser_put_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -313,22 +355,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_put_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -347,6 +374,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_delete_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -355,20 +383,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_delete_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
@@ -382,8 +397,24 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -392,20 +423,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_switch.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
@@ -418,29 +436,32 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_model_enable_disable = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@api.expect(parser_model_enable_disable)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.enable_model(
|
||||
@@ -454,25 +475,14 @@ class ModelProviderModelEnableApi(Resource):
|
||||
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@api.expect(parser_model_enable_disable)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.disable_model(
|
||||
@@ -482,28 +492,31 @@ class ModelProviderModelDisableApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
@api.expect(parser_validate)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_validate.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -530,16 +543,19 @@ class ModelProviderModelValidateApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
parser_parameter = reqparse.RequestParser().add_argument(
|
||||
"model", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
@api.expect(parser_parameter)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_parameter.parse_args()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -37,19 +37,22 @@ class PluginDebuggingKeyApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_list = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
class PluginListApi(Resource):
|
||||
@api.expect(parser_list)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_list.parse_args()
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
@@ -58,14 +61,17 @@ class PluginListApi(Resource):
|
||||
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
|
||||
|
||||
|
||||
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
@api.expect(parser_latest)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
args = parser_latest.parse_args()
|
||||
|
||||
try:
|
||||
versions = PluginService.list_latest_versions(args["plugin_ids"])
|
||||
@@ -75,16 +81,19 @@ class PluginListLatestVersionsApi(Resource):
|
||||
return jsonable_encoder({"versions": versions})
|
||||
|
||||
|
||||
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@api.expect(parser_ids)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_ids.parse_args()
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
@@ -94,16 +103,19 @@ class PluginListInstallationsFromIdsApi(Resource):
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
parser_icon = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
.add_argument("filename", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/icon")
|
||||
class PluginIconApi(Resource):
|
||||
@api.expect(parser_icon)
|
||||
@setup_required
|
||||
def get(self):
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
.add_argument("filename", type=str, required=True, location="args")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_icon.parse_args()
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
@@ -157,8 +169,17 @@ class PluginUploadFromPkgApi(Resource):
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_github = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/github")
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@api.expect(parser_github)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -166,13 +187,7 @@ class PluginUploadFromGithubApi(Resource):
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_github.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||
@@ -206,19 +221,21 @@ class PluginUploadFromBundleApi(Resource):
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_pkg = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/pkg")
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@api.expect(parser_pkg)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_pkg.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
@@ -233,8 +250,18 @@ class PluginInstallFromPkgApi(Resource):
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_githubapi = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/github")
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@api.expect(parser_githubapi)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -242,14 +269,7 @@ class PluginInstallFromGithubApi(Resource):
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_githubapi.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_github(
|
||||
@@ -265,8 +285,14 @@ class PluginInstallFromGithubApi(Resource):
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_marketplace = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/marketplace")
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@api.expect(parser_marketplace)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -274,10 +300,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_marketplace.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
@@ -292,19 +315,21 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_pkgapi = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
|
||||
class PluginFetchMarketplacePkgApi(Resource):
|
||||
@api.expect(parser_pkgapi)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_pkgapi.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@@ -319,8 +344,14 @@ class PluginFetchMarketplacePkgApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_fetch = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@api.expect(parser_fetch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -328,10 +359,7 @@ class PluginFetchManifestApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_fetch.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@@ -345,8 +373,16 @@ class PluginFetchManifestApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_tasks = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks")
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@api.expect(parser_tasks)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -354,12 +390,7 @@ class PluginFetchInstallTasksApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_tasks.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@@ -429,8 +460,16 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_marketplace_api = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@api.expect(parser_marketplace_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -438,12 +477,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_marketplace_api.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@@ -455,8 +489,19 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_github_post = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/github")
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@api.expect(parser_github_post)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -464,15 +509,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_github_post.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@@ -489,15 +526,20 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_uninstall = reqparse.RequestParser().add_argument(
|
||||
"plugin_installation_id", type=str, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/uninstall")
|
||||
class PluginUninstallApi(Resource):
|
||||
@api.expect(parser_uninstall)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
args = parser_uninstall.parse_args()
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -507,8 +549,16 @@ class PluginUninstallApi(Resource):
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_change_post = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("install_permission", type=str, required=True, location="json")
|
||||
.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/change")
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@api.expect(parser_change_post)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -518,12 +568,7 @@ class PluginChangePermissionApi(Resource):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("install_permission", type=str, required=True, location="json")
|
||||
.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_change_post.parse_args()
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
@@ -558,8 +603,20 @@ class PluginFetchPermissionApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_dynamic = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
.add_argument("provider", type=str, required=True, location="args")
|
||||
.add_argument("action", type=str, required=True, location="args")
|
||||
.add_argument("parameter", type=str, required=True, location="args")
|
||||
.add_argument("credential_id", type=str, required=False, location="args")
|
||||
.add_argument("provider_type", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@api.expect(parser_dynamic)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -571,16 +628,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
|
||||
user_id = current_user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
.add_argument("provider", type=str, required=True, location="args")
|
||||
.add_argument("action", type=str, required=True, location="args")
|
||||
.add_argument("parameter", type=str, required=True, location="args")
|
||||
.add_argument("credential_id", type=str, required=False, location="args")
|
||||
.add_argument("provider_type", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_dynamic.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
@@ -599,8 +647,16 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
parser_change = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("permission", type=dict, required=True, location="json")
|
||||
.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@api.expect(parser_change)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -609,12 +665,7 @@ class PluginChangePreferencesApi(Resource):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("permission", type=dict, required=True, location="json")
|
||||
.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_change.parse_args()
|
||||
|
||||
permission = args["permission"]
|
||||
|
||||
@@ -694,8 +745,12 @@ class PluginFetchPreferencesApi(Resource):
|
||||
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
|
||||
|
||||
|
||||
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@api.expect(parser_exclude)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -703,8 +758,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
args = parser_exclude.parse_args()
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
@@ -52,8 +52,19 @@ def is_valid_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
parser_tool = reqparse.RequestParser().add_argument(
|
||||
"type",
|
||||
type=str,
|
||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="args",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-providers")
|
||||
class ToolProviderListApi(Resource):
|
||||
@api.expect(parser_tool)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -62,15 +73,7 @@ class ToolProviderListApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
req = reqparse.RequestParser().add_argument(
|
||||
"type",
|
||||
type=str,
|
||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="args",
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_tool.parse_args()
|
||||
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
|
||||
|
||||
@@ -102,8 +105,14 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
||||
parser_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@api.expect(parser_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -112,10 +121,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_delete.parse_args()
|
||||
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
tenant_id,
|
||||
@@ -124,8 +130,17 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_add = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||
class ToolBuiltinProviderAddApi(Resource):
|
||||
@api.expect(parser_add)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -134,13 +149,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_add.parse_args()
|
||||
|
||||
if args["type"] not in CredentialType.values():
|
||||
raise ValueError(f"Invalid credential type: {args['type']}")
|
||||
@@ -155,8 +164,17 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@api.expect(parser_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -168,14 +186,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_update.parse_args()
|
||||
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
@@ -213,8 +224,22 @@ class ToolBuiltinProviderIconApi(Resource):
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
parser_api_add = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
||||
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/add")
|
||||
class ToolApiProviderAddApi(Resource):
|
||||
@api.expect(parser_api_add)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -226,19 +251,7 @@ class ToolApiProviderAddApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
||||
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_api_add.parse_args()
|
||||
|
||||
return ApiToolManageService.create_api_tool_provider(
|
||||
user_id,
|
||||
@@ -254,8 +267,12 @@ class ToolApiProviderAddApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/remote")
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@api.expect(parser_remote)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -264,9 +281,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_remote.parse_args()
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
user_id,
|
||||
@@ -275,8 +290,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_tools = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/tools")
|
||||
class ToolApiProviderListToolsApi(Resource):
|
||||
@api.expect(parser_tools)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -285,11 +306,7 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_tools.parse_args()
|
||||
|
||||
return jsonable_encoder(
|
||||
ApiToolManageService.list_api_tool_provider_tools(
|
||||
@@ -300,8 +317,23 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_api_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/update")
|
||||
class ToolApiProviderUpdateApi(Resource):
|
||||
@api.expect(parser_api_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -313,20 +345,7 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_api_update.parse_args()
|
||||
|
||||
return ApiToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
@@ -343,8 +362,14 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_api_delete = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/delete")
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@api.expect(parser_api_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -356,11 +381,7 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_api_delete.parse_args()
|
||||
|
||||
return ApiToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
@@ -369,8 +390,12 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/get")
|
||||
class ToolApiProviderGetApi(Resource):
|
||||
@api.expect(parser_get)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -379,11 +404,7 @@ class ToolApiProviderGetApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_get.parse_args()
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
@@ -407,40 +428,44 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_schema = reqparse.RequestParser().add_argument(
|
||||
"schema", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/schema")
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@api.expect(parser_schema)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"schema", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_schema.parse_args()
|
||||
|
||||
return ApiToolManageService.parser_api_schema(
|
||||
schema=args["schema"],
|
||||
)
|
||||
|
||||
|
||||
parser_pre = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
|
||||
class ToolApiProviderPreviousTestApi(Resource):
|
||||
@api.expect(parser_pre)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_pre.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_tenant_id,
|
||||
@@ -453,8 +478,22 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_create = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
|
||||
class ToolWorkflowProviderCreateApi(Resource):
|
||||
@api.expect(parser_create)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -466,19 +505,7 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
reqparser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = reqparser.parse_args()
|
||||
args = parser_create.parse_args()
|
||||
|
||||
return WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=user_id,
|
||||
@@ -494,8 +521,22 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_workflow_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
|
||||
class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@api.expect(parser_workflow_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -507,19 +548,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
reqparser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = reqparser.parse_args()
|
||||
args = parser_workflow_update.parse_args()
|
||||
|
||||
if not args["workflow_tool_id"]:
|
||||
raise ValueError("incorrect workflow_tool_id")
|
||||
@@ -538,8 +567,14 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_workflow_delete = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
|
||||
class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@api.expect(parser_workflow_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -551,11 +586,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
reqparser = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
args = reqparser.parse_args()
|
||||
args = parser_workflow_delete.parse_args()
|
||||
|
||||
return WorkflowToolManageService.delete_workflow_tool(
|
||||
user_id,
|
||||
@@ -564,8 +595,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_wf_get = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
|
||||
class ToolWorkflowProviderGetApi(Resource):
|
||||
@api.expect(parser_wf_get)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -574,13 +613,7 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_get.parse_args()
|
||||
|
||||
if args.get("workflow_tool_id"):
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
||||
@@ -600,8 +633,14 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||
return jsonable_encoder(tool)
|
||||
|
||||
|
||||
parser_wf_tools = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
|
||||
class ToolWorkflowProviderListToolApi(Resource):
|
||||
@api.expect(parser_wf_tools)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -610,11 +649,7 @@ class ToolWorkflowProviderListToolApi(Resource):
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_tools.parse_args()
|
||||
|
||||
return jsonable_encoder(
|
||||
WorkflowToolManageService.list_single_workflow_tools(
|
||||
@@ -790,32 +825,40 @@ class ToolOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_default_cred = reqparse.RequestParser().add_argument(
|
||||
"id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@api.expect(parser_default_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_default_cred.parse_args()
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||
)
|
||||
|
||||
|
||||
parser_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@api.expect(parser_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_custom.parse_args()
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -878,25 +921,44 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
parser_mcp = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
parser_mcp_put = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
parser_mcp_delete = reqparse.RequestParser().add_argument(
|
||||
"provider_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp")
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@api.expect(parser_mcp)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_mcp.parse_args()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Parse and validate models
|
||||
@@ -921,24 +983,12 @@ class ToolProviderMCPApi(Resource):
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@api.expect(parser_mcp_put)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_mcp_put.parse_args()
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -972,14 +1022,12 @@ class ToolProviderMCPApi(Resource):
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_mcp_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_mcp_delete.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
@@ -988,18 +1036,21 @@ class ToolProviderMCPApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_auth = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@api.expect(parser_auth)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_auth.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@@ -1097,15 +1148,18 @@ class ToolMCPUpdateApi(Resource):
|
||||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
parser_cb = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/mcp/oauth/callback")
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
@api.expect(parser_cb)
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_cb.parse_args()
|
||||
state_key = args["state"]
|
||||
authorization_code = args["code"]
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import (
|
||||
@@ -21,6 +21,7 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -83,7 +84,7 @@ class TenantListApi(Resource):
|
||||
"name": tenant.name,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
|
||||
"current": tenant.id == current_tenant_id if current_tenant_id else False,
|
||||
}
|
||||
|
||||
@@ -149,15 +150,18 @@ class TenantApi(Resource):
|
||||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
class SwitchWorkspaceApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_switch.parse_args()
|
||||
|
||||
# check if tenant_id is valid, 403 if not
|
||||
try:
|
||||
@@ -241,16 +245,19 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
return {"id": upload_file.id}, 201
|
||||
|
||||
|
||||
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/info")
|
||||
class WorkspaceInfoApi(Resource):
|
||||
@api.expect(parser_info)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_info.parse_args()
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
@@ -10,6 +10,7 @@ from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant
|
||||
@@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
abort(
|
||||
403,
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
|
||||
@@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
|
||||
@@ -2,6 +2,7 @@ from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
@@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
|
||||
# validate args
|
||||
args = segment_create_parser.parse_args()
|
||||
if args["segments"] is not None:
|
||||
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
|
||||
if segments_limit > 0 and len(args["segments"]) > segments_limit:
|
||||
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
|
||||
|
||||
for args_item in args["segments"]:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -67,6 +68,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
# If caller needs end-user context, attach EndUser to current_user
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
@@ -75,7 +77,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
# use default-user
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
@@ -90,6 +91,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
# Set EndUser as current logged-in user for flask_login.current_user
|
||||
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
|
||||
else:
|
||||
# For service API without end-user context, ensure an Account is logged in
|
||||
# so services relying on current_account_with_tenant() work correctly.
|
||||
tenant_owner_info = (
|
||||
db.session.query(Tenant, Account)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.join(Account, TenantAccountJoin.account_id == Account.id)
|
||||
.where(
|
||||
Tenant.id == app_model.tenant_id,
|
||||
TenantAccountJoin.role == "owner",
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if tenant_owner_info:
|
||||
tenant_model, account = tenant_owner_info
|
||||
account.current_tenant = tenant_model
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account not found or tenant is not active.")
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
@@ -139,7 +162,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
raise Forbidden(
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
|
||||
)
|
||||
|
||||
@@ -37,8 +37,7 @@ def trigger_endpoint(endpoint_id: str):
|
||||
return jsonify({"error": "Endpoint not found"}), 404
|
||||
return response
|
||||
except ValueError as e:
|
||||
logger.exception("Endpoint processing failed for {endpoint_id}: {e}")
|
||||
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 500
|
||||
except Exception as e:
|
||||
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 400
|
||||
except Exception:
|
||||
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||
return jsonify({"error": "Internal server error"}), 500
|
||||
|
||||
@@ -88,12 +88,6 @@ class AudioApi(WebApiResource):
|
||||
|
||||
@web_ns.route("/text-to-audio")
|
||||
class TextApi(WebApiResource):
|
||||
text_to_audio_response_fields = {
|
||||
"audio_url": fields.String,
|
||||
"duration": fields.Float,
|
||||
}
|
||||
|
||||
@marshal_with(text_to_audio_response_fields)
|
||||
@web_ns.doc("Text to Audio")
|
||||
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@web_ns.doc(
|
||||
|
||||
@@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ class AppRunner:
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: Mapping[str, str],
|
||||
files: Sequence["File"],
|
||||
query: str | None = None,
|
||||
query: str = "",
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
@@ -105,7 +105,7 @@ class AppRunner:
|
||||
app_mode=AppMode.value_of(app_record.mode),
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query or "",
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
|
||||
@@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
query=application_generate_entity.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
|
||||
@@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
if rag_pipeline_invoke_entities:
|
||||
# store the rag_pipeline_invoke_entities to object storage
|
||||
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
|
||||
name = "rag_pipeline_invoke_entities.json"
|
||||
# Convert list to proper JSON string
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
|
||||
|
||||
if redis_client.get(tenant_pipeline_task_key):
|
||||
# Add to waiting queue using List operations (lpush)
|
||||
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
|
||||
else:
|
||||
# Set flag and execute task
|
||||
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
else:
|
||||
priority_rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=upload_file.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
|
||||
# return batch, dataset, documents
|
||||
return {
|
||||
"batch": batch,
|
||||
|
||||
@@ -39,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@@ -139,8 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
|
||||
# start node get inputs
|
||||
# for trigger debug run, not prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
|
||||
@@ -44,6 +44,9 @@ class InvokeFrom(StrEnum):
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
|
||||
# VALIDATION indicates that this invocation is from validation.
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
@@ -110,6 +113,11 @@ class AppGenerateEntity(BaseModel):
|
||||
|
||||
inputs: Mapping[str, Any]
|
||||
files: Sequence[File]
|
||||
|
||||
# Unique identifier of the user initiating the execution.
|
||||
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
|
||||
#
|
||||
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
|
||||
user_id: str
|
||||
|
||||
# extras
|
||||
@@ -135,7 +143,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
app_config: EasyUIBasedAppConfig = None # type: ignore
|
||||
model_conf: ModelConfigWithCredentialsEntity
|
||||
|
||||
query: str | None = None
|
||||
query: str = ""
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@@ -1,15 +1,64 @@
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from models.model import AppMode
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
# Wrapper types for `WorkflowAppGenerateEntity` and
|
||||
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
|
||||
# and correct reconstruction of the entity field during (de)serialization.
|
||||
class _WorkflowGenerateEntityWrapper(BaseModel):
|
||||
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
|
||||
entity: WorkflowAppGenerateEntity
|
||||
|
||||
|
||||
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
|
||||
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
|
||||
entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
_GenerateEntityUnion: TypeAlias = Annotated[
|
||||
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class WorkflowResumptionContext(BaseModel):
|
||||
"""WorkflowResumptionContext captures all state necessary for resumption."""
|
||||
|
||||
version: Literal["1"] = "1"
|
||||
|
||||
# Only workflow / chatflow could be paused.
|
||||
generate_entity: _GenerateEntityUnion
|
||||
serialized_graph_runtime_state: str
|
||||
|
||||
def dumps(self) -> str:
|
||||
return self.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def loads(cls, value: str) -> Self:
|
||||
return cls.model_validate_json(value)
|
||||
|
||||
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
|
||||
return self.generate_entity.entity
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(self, session_factory: Engine | sessionmaker[Session], state_owner_user_id: str):
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Engine | sessionmaker[Session],
|
||||
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
|
||||
state_owner_user_id: str,
|
||||
):
|
||||
"""Create a PauseStatePersistenceLayer.
|
||||
|
||||
The `state_owner_user_id` is used when creating state file for pause.
|
||||
@@ -19,6 +68,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
self._generate_entity = generate_entity
|
||||
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
@@ -49,13 +99,25 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
|
||||
entity_wrapper: _GenerateEntityUnion
|
||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||
else:
|
||||
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
|
||||
|
||||
state = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
|
||||
generate_entity=entity_wrapper,
|
||||
)
|
||||
|
||||
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
assert workflow_run_id is not None
|
||||
repo = self._get_repo()
|
||||
repo.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=self.graph_runtime_state.dumps(),
|
||||
state=state.dumps(),
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
|
||||
@@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
|
||||
@@ -140,7 +140,27 @@ class MessageCycleManager:
|
||||
if not self._application_generate_entity.app_config.additional_features:
|
||||
raise ValueError("Additional features not found")
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
|
||||
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
|
||||
|
||||
# Add new unique resources from the event
|
||||
for resource in event.retriever_resources or []:
|
||||
if not resource:
|
||||
continue
|
||||
|
||||
is_duplicate = (
|
||||
resource.dataset_id
|
||||
and resource.document_id
|
||||
and (resource.dataset_id, resource.document_id) in existing_ids
|
||||
)
|
||||
|
||||
if not is_duplicate:
|
||||
merged_resources.append(resource)
|
||||
|
||||
for i, resource in enumerate(merged_resources, 1):
|
||||
resource.position = i
|
||||
|
||||
self._task_state.metadata.retriever_resources = merged_resources
|
||||
|
||||
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
|
||||
"""
|
||||
|
||||
15
api/core/entities/document_task.py
Normal file
15
api/core/entities/document_task.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentTask:
|
||||
"""Document task entity for document indexing operations.
|
||||
|
||||
This class represents a document indexing task that can be queued
|
||||
and processed by the document indexing system.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
document_ids: Sequence[str]
|
||||
@@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
|
||||
# Return composite sort key: (model_type value, model position index)
|
||||
return (model.model_type.value, position_index)
|
||||
|
||||
# Deduplicate
|
||||
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
|
||||
|
||||
# Sort using the composite sort key
|
||||
return sorted(provider_models, key=get_sort_key)
|
||||
|
||||
|
||||
@@ -74,6 +74,10 @@ class File(BaseModel):
|
||||
storage_key: str | None = None,
|
||||
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
|
||||
url: str | None = None,
|
||||
# Legacy compatibility fields - explicitly handle known extra fields
|
||||
tool_file_id: str | None = None,
|
||||
upload_file_id: str | None = None,
|
||||
datasource_file_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
|
||||
@@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(
|
||||
f"""
|
||||
// declare main function
|
||||
{cls._code_placeholder}
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
// decode and prepare input object
|
||||
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
|
||||
@@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
var output_json = JSON.stringify(output_obj)
|
||||
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
|
||||
console.log(result)
|
||||
"""
|
||||
)
|
||||
""")
|
||||
return runner_script
|
||||
|
||||
@@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
# declare main function
|
||||
{cls._code_placeholder}
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
import json
|
||||
from base64 import b64decode
|
||||
|
||||
@@ -138,6 +138,10 @@ class StreamableHTTPTransport:
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
# ping event send by server will be recognized as a message event with empty data by httpx-sse's SSEDecoder
|
||||
if not sse.data.strip():
|
||||
return False
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug("SSE message: %s", message)
|
||||
|
||||
@@ -52,7 +52,7 @@ class OpenAIModeration(Moderation):
|
||||
text = "\n".join(str(inputs.values()))
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
|
||||
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
|
||||
)
|
||||
|
||||
openai_moderation = model_instance.invoke_moderation(text=text)
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
from sqlalchemy import select
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
@@ -30,9 +31,10 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,22 +101,45 @@ def datetime_to_nanos(dt: datetime | None) -> int:
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def string_to_trace_id128(string: str | None) -> int:
|
||||
"""
|
||||
Convert any input string into a stable 128-bit integer trace ID.
|
||||
def error_to_string(error: Exception | str | None) -> str:
|
||||
"""Convert an error to a string with traceback information."""
|
||||
error_message = "Empty Stack Trace"
|
||||
if error:
|
||||
if isinstance(error, Exception):
|
||||
string_stacktrace = "".join(traceback.format_exception(error))
|
||||
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
|
||||
else:
|
||||
error_message = str(error)
|
||||
return error_message
|
||||
|
||||
This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
|
||||
It's suitable for generating consistent, unique identifiers from strings.
|
||||
"""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
# Take the first 16 bytes (128 bits) of the hash digest
|
||||
digest = hash_object.digest()[:16]
|
||||
def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
"""Set the status of the current span based on the presence of an error."""
|
||||
if error:
|
||||
error_string = error_to_string(error)
|
||||
current_span.set_status(Status(StatusCode.ERROR, error_string))
|
||||
|
||||
# Convert to a 128-bit integer
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
if isinstance(error, Exception):
|
||||
current_span.record_exception(error)
|
||||
else:
|
||||
exception_type = error.__class__.__name__
|
||||
exception_message = str(error)
|
||||
if not exception_message:
|
||||
exception_message = repr(error)
|
||||
attributes: dict[str, AttributeValue] = {
|
||||
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
|
||||
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
|
||||
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
|
||||
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
|
||||
}
|
||||
current_span.add_event(name="exception", attributes=attributes)
|
||||
else:
|
||||
current_span.set_status(Status(StatusCode.OK))
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any) -> str:
|
||||
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
|
||||
return json.dumps(obj, default=str, ensure_ascii=False)
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
@@ -131,9 +156,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.propagator = TraceContextTextMapPropagator()
|
||||
self.dify_trace_ids: set[str] = set()
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info("[Arize/Phoenix] Trace: %s", trace_info)
|
||||
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
|
||||
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
@@ -151,7 +179,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
|
||||
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
@@ -166,15 +194,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
@@ -186,31 +208,58 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
# Through workflow_run_id, get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
for node_execution in workflow_node_executions:
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
inputs_value = node_execution.inputs or {}
|
||||
outputs_value = node_execution.outputs or {}
|
||||
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data or {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
node_metadata.update(
|
||||
{
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
|
||||
}
|
||||
)
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
@@ -223,8 +272,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
@@ -236,17 +286,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
workflow_span_context = set_span_in_context(workflow_span)
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=workflow_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -260,11 +313,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
|
||||
if model:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
|
||||
outputs = (
|
||||
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
|
||||
)
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
|
||||
@@ -275,8 +325,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
|
||||
node_span.set_attributes(llm_attributes)
|
||||
finally:
|
||||
if node_execution.status == "failed":
|
||||
set_span_status(node_span, node_execution.error)
|
||||
else:
|
||||
set_span_status(node_span)
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(workflow_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(workflow_span)
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
@@ -322,34 +380,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=message_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
message_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert outputs to string based on type
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
@@ -383,26 +425,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
message_span_context = set_span_in_context(message_span)
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
context=message_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
llm_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(llm_span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(llm_span)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(message_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(message_span)
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
@@ -418,15 +460,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
@@ -445,19 +481,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
@@ -480,15 +511,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
@@ -499,19 +524,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
@@ -533,15 +553,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
@@ -554,19 +568,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
@@ -580,20 +589,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
|
||||
|
||||
# Create span context with the same trace_id as the parent
|
||||
# todo: Create with the appropriate parent span context, so that the tool span is
|
||||
# a child of the appropriate span (e.g. message span)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=tool_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
@@ -612,19 +610,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
@@ -641,15 +634,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
@@ -663,22 +650,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def ensure_root_span(self, dify_trace_id: str | None):
|
||||
"""Ensure a unique root span exists for the given Dify trace ID."""
|
||||
if str(dify_trace_id) not in self.dify_trace_ids:
|
||||
self.carrier: dict[str, str] = {}
|
||||
|
||||
root_span = self.tracer.start_span(name="Dify")
|
||||
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
|
||||
root_span.set_attribute("dify_project_name", str(self.project))
|
||||
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
|
||||
|
||||
with use_span(root_span, end_on_exit=False):
|
||||
self.propagator.inject(carrier=self.carrier)
|
||||
|
||||
set_span_status(root_span)
|
||||
root_span.end()
|
||||
self.dify_trace_ids.add(str(dify_trace_id))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
@@ -698,26 +697,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = db.session.scalars(
|
||||
select(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
).all()
|
||||
return workflow_nodes
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
"""Helper method to construct LLM attributes with passed prompts."""
|
||||
attributes = {}
|
||||
|
||||
@@ -5,6 +5,7 @@ Tencent APM Trace Client - handles network operations, metrics, and API communic
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
@@ -110,6 +111,7 @@ class TencentTraceClient:
|
||||
self.span_contexts: dict[int, trace_api.SpanContext] = {}
|
||||
|
||||
self.meter: Meter | None = None
|
||||
self.meter_provider: MeterProvider | None = None
|
||||
self.hist_llm_duration: Histogram | None = None
|
||||
self.hist_token_usage: Histogram | None = None
|
||||
self.hist_time_to_first_token: Histogram | None = None
|
||||
@@ -119,7 +121,6 @@ class TencentTraceClient:
|
||||
|
||||
# Metrics exporter and instruments
|
||||
try:
|
||||
from opentelemetry import metrics
|
||||
from opentelemetry.sdk.metrics import Histogram, MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
|
||||
|
||||
@@ -202,9 +203,11 @@ class TencentTraceClient:
|
||||
)
|
||||
|
||||
if metric_reader is not None:
|
||||
# Use instance-level MeterProvider instead of global to support config changes
|
||||
# without worker restart. Each TencentTraceClient manages its own MeterProvider.
|
||||
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(provider)
|
||||
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
|
||||
self.meter_provider = provider
|
||||
self.meter = provider.get_meter("dify-sdk", dify_config.project.version)
|
||||
|
||||
# LLM operation duration histogram
|
||||
self.hist_llm_duration = self.meter.create_histogram(
|
||||
@@ -244,6 +247,7 @@ class TencentTraceClient:
|
||||
self.metric_reader = metric_reader
|
||||
else:
|
||||
self.meter = None
|
||||
self.meter_provider = None
|
||||
self.hist_llm_duration = None
|
||||
self.hist_token_usage = None
|
||||
self.hist_time_to_first_token = None
|
||||
@@ -253,6 +257,7 @@ class TencentTraceClient:
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
|
||||
self.meter = None
|
||||
self.meter_provider = None
|
||||
self.hist_llm_duration = None
|
||||
self.hist_token_usage = None
|
||||
self.hist_time_to_first_token = None
|
||||
@@ -279,6 +284,14 @@ class TencentTraceClient:
|
||||
if attributes:
|
||||
for k, v in attributes.items():
|
||||
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
LLM_OPERATION_DURATION,
|
||||
latency_seconds,
|
||||
json.dumps(attrs, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
|
||||
@@ -317,6 +330,13 @@ class TencentTraceClient:
|
||||
"server.address": server_address,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
|
||||
GEN_AI_TOKEN_USAGE,
|
||||
token_count,
|
||||
json.dumps(attributes, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_token_usage.record(token_count, attributes) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)
|
||||
@@ -344,6 +364,13 @@ class TencentTraceClient:
|
||||
"stream": "true",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
|
||||
ttft_seconds,
|
||||
json.dumps(attributes, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_time_to_first_token.record(ttft_seconds, attributes) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)
|
||||
@@ -371,6 +398,13 @@ class TencentTraceClient:
|
||||
"stream": "true",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
GEN_AI_STREAMING_TIME_TO_GENERATE,
|
||||
ttg_seconds,
|
||||
json.dumps(attributes, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_time_to_generate.record(ttg_seconds, attributes) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)
|
||||
@@ -390,6 +424,14 @@ class TencentTraceClient:
|
||||
if attributes:
|
||||
for k, v in attributes.items():
|
||||
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
GEN_AI_TRACE_DURATION,
|
||||
duration_seconds,
|
||||
json.dumps(attrs, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_trace_duration.record(duration_seconds, attrs) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)
|
||||
@@ -474,11 +516,19 @@ class TencentTraceClient:
|
||||
|
||||
if self.tracer_provider:
|
||||
self.tracer_provider.shutdown()
|
||||
|
||||
# Shutdown instance-level meter provider
|
||||
if self.meter_provider is not None:
|
||||
try:
|
||||
self.meter_provider.shutdown() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)
|
||||
|
||||
if self.metric_reader is not None:
|
||||
try:
|
||||
self.metric_reader.shutdown() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Error during client shutdown")
|
||||
|
||||
@@ -246,7 +246,7 @@ class RequestFetchAppInfo(BaseModel):
|
||||
|
||||
class TriggerInvokeEventResponse(BaseModel):
|
||||
variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
cancelled: bool | None = False
|
||||
cancelled: bool = Field(default=False)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
@@ -147,7 +147,8 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
return cast(str, info["version"]["number"])
|
||||
# remove any suffix like "-SNAPSHOT" from the version string
|
||||
return cast(str, info["version"]["number"]).split("-")[0]
|
||||
|
||||
def _check_version(self):
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
|
||||
@@ -39,11 +39,13 @@ class WeaviateConfig(BaseModel):
|
||||
|
||||
Attributes:
|
||||
endpoint: Weaviate server endpoint URL
|
||||
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
|
||||
api_key: Optional API key for authentication
|
||||
batch_size: Number of objects to batch per insert operation
|
||||
"""
|
||||
|
||||
endpoint: str
|
||||
grpc_endpoint: str | None = None
|
||||
api_key: str | None = None
|
||||
batch_size: int = 100
|
||||
|
||||
@@ -88,9 +90,22 @@ class WeaviateVector(BaseVector):
|
||||
http_secure = p.scheme == "https"
|
||||
http_port = p.port or (443 if http_secure else 80)
|
||||
|
||||
grpc_host = host
|
||||
grpc_secure = http_secure
|
||||
grpc_port = 443 if grpc_secure else 50051
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
)
|
||||
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
||||
grpc_host = grpc_p.hostname or "localhost"
|
||||
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
||||
grpc_secure = grpc_p.scheme == "grpcs"
|
||||
else:
|
||||
# Infer from HTTP endpoint as fallback
|
||||
grpc_host = host
|
||||
grpc_secure = http_secure
|
||||
grpc_port = 443 if grpc_secure else 50051
|
||||
|
||||
client = weaviate.connect_to_custom(
|
||||
http_host=host,
|
||||
@@ -432,6 +447,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
|
||||
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
|
||||
api_key=dify_config.WEAVIATE_API_KEY,
|
||||
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
|
||||
),
|
||||
|
||||
@@ -152,13 +152,15 @@ class WordExtractor(BaseExtractor):
|
||||
# Initialize a row, all of which are empty by default
|
||||
row_cells = [""] * total_cols
|
||||
col_index = 0
|
||||
for cell in row.cells:
|
||||
while col_index < len(row.cells):
|
||||
# make sure the col_index is not out of range
|
||||
while col_index < total_cols and row_cells[col_index] != "":
|
||||
while col_index < len(row.cells) and row_cells[col_index] != "":
|
||||
col_index += 1
|
||||
# if col_index is out of range the loop is jumped
|
||||
if col_index >= total_cols:
|
||||
if col_index >= len(row.cells):
|
||||
break
|
||||
# get the correct cell
|
||||
cell = row.cells[col_index]
|
||||
cell_content = self._parse_cell(cell, image_map).strip()
|
||||
cell_colspan = cell.grid_span or 1
|
||||
for i in range(cell_colspan):
|
||||
|
||||
0
api/core/rag/pipeline/__init__.py
Normal file
0
api/core/rag/pipeline/__init__.py
Normal file
82
api/core/rag/pipeline/queue.py
Normal file
82
api/core/rag/pipeline/queue.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
|
||||
|
||||
|
||||
class TaskWrapper(BaseModel):
|
||||
data: Any
|
||||
|
||||
def serialize(self) -> str:
|
||||
return self.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
|
||||
return cls.model_validate_json(serialized_data)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueue:
|
||||
"""
|
||||
Simple queue for tenant isolated tasks, used for rag related tenant tasks isolation.
|
||||
It uses Redis list to store tasks, and Redis key to store task waiting flag.
|
||||
Support tasks that can be serialized by json.
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, unique_key: str):
|
||||
self._tenant_id = tenant_id
|
||||
self._unique_key = unique_key
|
||||
self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
|
||||
self._task_key = f"tenant_{unique_key}_task:{tenant_id}"
|
||||
|
||||
def get_task_key(self):
|
||||
return redis_client.get(self._task_key)
|
||||
|
||||
def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
|
||||
redis_client.setex(self._task_key, ttl, 1)
|
||||
|
||||
def delete_task_key(self):
|
||||
redis_client.delete(self._task_key)
|
||||
|
||||
def push_tasks(self, tasks: Sequence[Any]):
|
||||
serialized_tasks = []
|
||||
for task in tasks:
|
||||
# Store str list directly, maintaining full compatibility for pipeline scenarios
|
||||
if isinstance(task, str):
|
||||
serialized_tasks.append(task)
|
||||
else:
|
||||
# Use TaskWrapper to do JSON serialization for non-string tasks
|
||||
wrapper = TaskWrapper(data=task)
|
||||
serialized_data = wrapper.serialize()
|
||||
serialized_tasks.append(serialized_data)
|
||||
|
||||
if not serialized_tasks:
|
||||
return
|
||||
|
||||
redis_client.lpush(self._queue, *serialized_tasks)
|
||||
|
||||
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
|
||||
if count <= 0:
|
||||
return []
|
||||
|
||||
tasks = []
|
||||
for _ in range(count):
|
||||
serialized_task = redis_client.rpop(self._queue)
|
||||
if not serialized_task:
|
||||
break
|
||||
|
||||
if isinstance(serialized_task, bytes):
|
||||
serialized_task = serialized_task.decode("utf-8")
|
||||
|
||||
try:
|
||||
wrapper = TaskWrapper.deserialize(serialized_task)
|
||||
tasks.append(wrapper.data)
|
||||
except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
|
||||
# Fall back to raw string for legacy format or invalid JSON
|
||||
tasks.append(serialized_task)
|
||||
|
||||
return tasks
|
||||
@@ -7,8 +7,7 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import Float, and_, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy import and_, or_, select
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -1023,60 +1022,55 @@ class DatasetRetrieval:
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return
|
||||
return filters
|
||||
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
|
||||
@@ -210,12 +210,13 @@ class Tool(ABC):
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||
type=ToolInvokeMessage.MessageType.JSON,
|
||||
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
|
||||
)
|
||||
|
||||
def create_variable_message(
|
||||
|
||||
@@ -129,6 +129,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
|
||||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
@@ -52,6 +55,11 @@ class MCPTool(Tool):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
@@ -97,6 +105,10 @@ class MCPTool(Tool):
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
||||
@@ -228,29 +228,41 @@ class ToolEngine:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
result = ""
|
||||
parts: list[str] = []
|
||||
json_parts: list[str] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += cast(ToolInvokeMessage.TextMessage, response.message).text
|
||||
parts.append(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += (
|
||||
parts.append(
|
||||
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
|
||||
+ " please tell user to check it."
|
||||
)
|
||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
result += (
|
||||
parts.append(
|
||||
"image has been created and sent to user already, "
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
result += json.dumps(
|
||||
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
|
||||
ensure_ascii=False,
|
||||
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
|
||||
if json_message.suppress_output:
|
||||
continue
|
||||
json_parts.append(
|
||||
json.dumps(
|
||||
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result += str(response.message)
|
||||
parts.append(str(response.message))
|
||||
|
||||
return result
|
||||
# Add JSON parts, avoiding duplicates from text parts.
|
||||
if json_parts:
|
||||
existing_parts = set(parts)
|
||||
parts.extend(p for p in json_parts if p not in existing_parts)
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary_and_text(
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@@ -32,7 +33,6 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
@@ -618,12 +618,28 @@ class ToolManager:
|
||||
"""
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use DISTINCT ON
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
else:
|
||||
# MySQL: Use window function to achieve same result
|
||||
sql = """
|
||||
SELECT id FROM (
|
||||
SELECT id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY tenant_id, provider
|
||||
ORDER BY is_default DESC, created_at DESC
|
||||
) as rn
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
) ranked WHERE rn = 1
|
||||
"""
|
||||
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@@ -117,7 +117,7 @@ class WorkflowTool(Tool):
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
yield self.create_json_message(outputs)
|
||||
yield self.create_json_message(outputs, suppress_output=True)
|
||||
|
||||
@property
|
||||
def latest_usage(self) -> LLMUsage:
|
||||
|
||||
@@ -208,7 +208,7 @@ class SubscriptionBuilder(BaseModel):
|
||||
endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
|
||||
parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] = Field(..., description="The credentials of the subscription builder")
|
||||
credentials: Mapping[str, Any] = Field(..., description="The credentials of the subscription builder")
|
||||
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||
credential_expires_at: int | None = Field(
|
||||
default=None, description="The credential expires at of the subscription builder"
|
||||
@@ -227,7 +227,7 @@ class SubscriptionBuilderUpdater(BaseModel):
|
||||
name: str | None = Field(default=None, description="The name of the subscription builder")
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] | None = Field(
|
||||
credentials: Mapping[str, Any] | None = Field(
|
||||
default=None, description="The credentials of the subscription builder"
|
||||
)
|
||||
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||
|
||||
@@ -13,14 +13,14 @@ import contexts
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError, PluginNotFoundError
|
||||
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
|
||||
from core.plugin.impl.trigger import PluginTriggerClient
|
||||
from core.trigger.entities.entities import (
|
||||
EventEntity,
|
||||
Subscription,
|
||||
UnsubscribeResult,
|
||||
)
|
||||
from core.trigger.errors import EventIgnoreError, TriggerPluginInvokeError
|
||||
from core.trigger.errors import EventIgnoreError
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
||||
@@ -189,13 +189,10 @@ class TriggerManager:
|
||||
request=request,
|
||||
payload=payload,
|
||||
)
|
||||
except EventIgnoreError as e:
|
||||
except EventIgnoreError:
|
||||
return TriggerInvokeEventResponse(variables={}, cancelled=True)
|
||||
except PluginInvokeError as e:
|
||||
logger.exception("Failed to invoke trigger event")
|
||||
raise TriggerPluginInvokeError(
|
||||
description=e.to_user_friendly_error(plugin_name=provider.entity.identity.name)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def subscribe_trigger(
|
||||
|
||||
@@ -202,6 +202,35 @@ class SegmentType(StrEnum):
|
||||
raise ValueError(f"element_type is only supported by array type, got {self}")
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: "SegmentType"):
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
match t:
|
||||
case (
|
||||
SegmentType.ARRAY_OBJECT
|
||||
| SegmentType.ARRAY_ANY
|
||||
| SegmentType.ARRAY_STRING
|
||||
| SegmentType.ARRAY_NUMBER
|
||||
| SegmentType.ARRAY_BOOLEAN
|
||||
):
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return variable_factory.build_segment("")
|
||||
case SegmentType.INTEGER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.FLOAT:
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return variable_factory.build_segment(False)
|
||||
case _:
|
||||
raise ValueError(f"unsupported variable type: {t}")
|
||||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
|
||||
@@ -114,9 +114,45 @@ class GraphValidator:
|
||||
raise GraphValidationError(issues)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _TriggerStartExclusivityValidator:
|
||||
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
|
||||
|
||||
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
start_node_id: str | None = None
|
||||
trigger_node_ids: list[str] = []
|
||||
|
||||
for node in graph.nodes.values():
|
||||
node_type = getattr(node, "node_type", None)
|
||||
if not isinstance(node_type, NodeType):
|
||||
continue
|
||||
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = node.id
|
||||
elif node_type.is_trigger_node:
|
||||
trigger_node_ids.append(node.id)
|
||||
|
||||
if start_node_id and trigger_node_ids:
|
||||
trigger_list = ", ".join(trigger_node_ids)
|
||||
return [
|
||||
GraphValidationIssue(
|
||||
code=self.conflict_code,
|
||||
message=(
|
||||
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
|
||||
),
|
||||
node_id=start_node_id,
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||
_EdgeEndpointValidator(),
|
||||
_RootNodeValidator(),
|
||||
_TriggerStartExclusivityValidator(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from uuid import uuid4
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -108,8 +107,8 @@ class Worker(threading.Thread):
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
node_id="unknown",
|
||||
node_type=NodeType.CODE,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
|
||||
@@ -6,12 +6,12 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetDocument
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@@ -597,79 +597,79 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
case "=" | "is":
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
@@ -12,7 +11,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
@@ -116,7 +114,7 @@ class VariableAssignerNode(Node):
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
# Over write the variable.
|
||||
@@ -143,24 +141,3 @@ class VariableAssignerNode(Node):
|
||||
process_data=common_helpers.set_updated_variables({}, updated_variables),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return variable_factory.build_segment("")
|
||||
case SegmentType.INTEGER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.FLOAT:
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from core.variables import SegmentType
|
||||
|
||||
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
@@ -16,7 +16,6 @@ from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNod
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
@@ -249,7 +248,7 @@ class VariableAssignerNode(Node):
|
||||
case Operation.OVER_WRITE:
|
||||
return value
|
||||
case Operation.CLEAR:
|
||||
return EMPTY_VALUE_MAPPING[variable.value_type]
|
||||
return SegmentType.get_zero_value(variable.value_type).to_object()
|
||||
case Operation.APPEND:
|
||||
return variable.value + [value]
|
||||
case Operation.EXTEND:
|
||||
|
||||
@@ -153,7 +153,11 @@ class VariablePool(BaseModel):
|
||||
return None
|
||||
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
segment: Segment | None = self.variable_dictionary[node_id].get(name)
|
||||
node_map = self.variable_dictionary.get(node_id)
|
||||
if node_map is None:
|
||||
return None
|
||||
|
||||
segment: Segment | None = node_map.get(name)
|
||||
|
||||
if segment is None:
|
||||
return None
|
||||
|
||||
@@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
||||
0
api/enums/__init__.py
Normal file
0
api/enums/__init__.py
Normal file
15
api/enums/cloud_plan.py
Normal file
15
api/enums/cloud_plan.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class CloudPlan(StrEnum):
|
||||
"""
|
||||
Enum representing user plan types in the cloud platform.
|
||||
|
||||
SANDBOX: Free/default plan with limited features
|
||||
PROFESSIONAL: Professional paid plan
|
||||
TEAM: Team collaboration paid plan
|
||||
"""
|
||||
|
||||
SANDBOX = auto()
|
||||
PROFESSIONAL = auto()
|
||||
TEAM = auto()
|
||||
@@ -116,6 +116,7 @@ app_partial_fields = {
|
||||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
}
|
||||
|
||||
|
||||
|
||||
134
api/libs/broadcast_channel/channel.py
Normal file
134
api/libs/broadcast_channel/channel.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Broadcast channel for Pub/Sub messaging.
|
||||
"""
|
||||
|
||||
import types
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Protocol, Self
|
||||
|
||||
|
||||
class Subscription(AbstractContextManager["Subscription"], Protocol):
|
||||
"""A subscription to a topic that provides an iterator over received messages.
|
||||
The subscription can be used as a context manager and will automatically
|
||||
close when exiting the context.
|
||||
|
||||
Note: `Subscription` instances are not thread-safe. Each thread should create its own
|
||||
subscription.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
"""`__iter__` returns an iterator used to consume the message from this subscription.
|
||||
|
||||
If the caller did not enter the context, `__iter__` may lazily perform the setup before
|
||||
yielding messages; otherwise `__enter__` handles it.”
|
||||
|
||||
If the subscription is closed, then the returned iterator exits without
|
||||
raising any error.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""close closes the subscription, releases any resources associated with it."""
|
||||
...
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
"""Receive the next message from the broadcast channel.
|
||||
|
||||
If `timeout` is specified, this method returns `None` if no message is
|
||||
received within the given period. If `timeout` is `None`, the call blocks
|
||||
until a message is received.
|
||||
|
||||
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
|
||||
cancel a blocking subscription.
|
||||
|
||||
:param timeout: timeout for receive message, in seconds.
|
||||
|
||||
Returns:
|
||||
bytes: The received message as a byte string, or
|
||||
None: If the timeout expires before a message is received.
|
||||
|
||||
Raises:
|
||||
SubscriptionClosed: If the subscription has already been closed.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Producer(Protocol):
|
||||
"""Producer is an interface for message publishing. It is already bound to a specific topic.
|
||||
|
||||
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, payload: bytes) -> None:
|
||||
"""Publish a message to the bounded topic."""
|
||||
...
|
||||
|
||||
|
||||
class Subscriber(Protocol):
|
||||
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
|
||||
|
||||
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def subscribe(self) -> Subscription:
|
||||
pass
|
||||
|
||||
|
||||
class Topic(Producer, Subscriber, Protocol):
|
||||
"""A named channel for publishing and subscribing to messages.
|
||||
|
||||
Topics provide both read and write access. For restricted access,
|
||||
use as_producer() for write-only view or as_subscriber() for read-only view.
|
||||
|
||||
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def as_producer(self) -> Producer:
|
||||
"""as_producer creates a write-only view for this topic."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
"""as_subscriber create a read-only view for this topic."""
|
||||
...
|
||||
|
||||
|
||||
class BroadcastChannel(Protocol):
|
||||
"""A broadcasting channel is a channel supporting broadcasting semantics.
|
||||
|
||||
Each channel is identified by a topic, different topics are isolated and do not affect each other.
|
||||
|
||||
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
|
||||
a specific topic, all subscription should receive the published message.
|
||||
|
||||
There are no restriction for the persistence of messages. Once a subscription is created, it
|
||||
should receive all subsequent messages published.
|
||||
|
||||
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def topic(self, topic: str) -> "Topic":
|
||||
"""topic returns a `Topic` instance for the given topic name."""
|
||||
...
|
||||
12
api/libs/broadcast_channel/exc.py
Normal file
12
api/libs/broadcast_channel/exc.py
Normal file
@@ -0,0 +1,12 @@
|
||||
class BroadcastChannelError(Exception):
|
||||
"""`BroadcastChannelError` is the base class for all exceptions related
|
||||
to `BroadcastChannel`."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubscriptionClosedError(BroadcastChannelError):
|
||||
"""SubscriptionClosedError means that the subscription has been closed and
|
||||
methods for consuming messages should not be called."""
|
||||
|
||||
pass
|
||||
3
api/libs/broadcast_channel/redis/__init__.py
Normal file
3
api/libs/broadcast_channel/redis/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .channel import BroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel"]
|
||||
200
api/libs/broadcast_channel/redis/channel.py
Normal file
200
api/libs/broadcast_channel/redis/channel.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
Redis Pub/Sub based broadcast channel implementation.
|
||||
|
||||
Provides "at most once" delivery semantics for messages published to channels.
|
||||
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||
|
||||
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis,
|
||||
):
|
||||
self._client = redis_client
|
||||
|
||||
def topic(self, topic: str) -> "Topic":
|
||||
return Topic(self._client, topic)
|
||||
|
||||
|
||||
class Topic:
|
||||
def __init__(self, redis_client: Redis, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.publish(self._topic, payload)
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisSubscription(
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
|
||||
|
||||
class _RedisSubscription(Subscription):
|
||||
def __init__(
|
||||
self,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
|
||||
self._dropped_count = 0
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
with self._start_lock:
|
||||
if self._started:
|
||||
return
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
if self._pubsub is None:
|
||||
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
|
||||
|
||||
self._pubsub.subscribe(self._topic)
|
||||
_logger.debug("Subscribed to channel %s", self._topic)
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-broadcast-{self._topic}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener_thread.start()
|
||||
self._started = True
|
||||
|
||||
def _listen(self) -> None:
|
||||
pubsub = self._pubsub
|
||||
assert pubsub is not None, "PubSub should not be None while starting listening."
|
||||
while not self._closed.is_set():
|
||||
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
if raw_message.get("type") != "message":
|
||||
continue
|
||||
|
||||
channel_field = raw_message.get("channel")
|
||||
if isinstance(channel_field, bytes):
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
elif isinstance(channel_field, str):
|
||||
channel_name = channel_field
|
||||
else:
|
||||
channel_name = str(channel_field)
|
||||
|
||||
if channel_name != self._topic:
|
||||
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
|
||||
continue
|
||||
|
||||
payload_bytes: bytes | None = raw_message.get("data")
|
||||
if not isinstance(payload_bytes, bytes):
|
||||
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
|
||||
_logger.debug("Listener thread stopped for channel %s", self._topic)
|
||||
pubsub.unsubscribe(self._topic)
|
||||
pubsub.close()
|
||||
_logger.debug("PubSub closed for topic %s", self._topic)
|
||||
self._pubsub = None
|
||||
|
||||
def _enqueue_message(self, payload: bytes) -> None:
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
self._queue.put_nowait(payload)
|
||||
return
|
||||
except queue.Full:
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
self._dropped_count += 1
|
||||
_logger.debug(
|
||||
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
|
||||
self._topic,
|
||||
self._dropped_count,
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
return
|
||||
|
||||
def _message_iterator(self) -> Generator[bytes, None, None]:
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
item = self._queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
self._start_if_needed()
|
||||
|
||||
try:
|
||||
item = self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
return item
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
|
||||
# method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=1.0)
|
||||
self._listener_thread = None
|
||||
@@ -2,6 +2,8 @@ import abc
|
||||
import datetime
|
||||
from typing import Protocol
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
class _NowFunction(Protocol):
|
||||
@abc.abstractmethod
|
||||
@@ -31,3 +33,51 @@ def ensure_naive_utc(dt: datetime.datetime) -> datetime.datetime:
|
||||
if dt.tzinfo is None:
|
||||
return dt
|
||||
return dt.astimezone(datetime.UTC).replace(tzinfo=None)
|
||||
|
||||
|
||||
def parse_time_range(
|
||||
start: str | None, end: str | None, tzname: str
|
||||
) -> tuple[datetime.datetime | None, datetime.datetime | None]:
|
||||
"""
|
||||
Parse time range strings and convert to UTC datetime objects.
|
||||
Handles DST ambiguity and non-existent times gracefully.
|
||||
|
||||
Args:
|
||||
start: Start time string (YYYY-MM-DD HH:MM)
|
||||
end: End time string (YYYY-MM-DD HH:MM)
|
||||
tzname: Timezone name
|
||||
|
||||
Returns:
|
||||
tuple: (start_datetime_utc, end_datetime_utc)
|
||||
|
||||
Raises:
|
||||
ValueError: When time range is invalid or start > end
|
||||
"""
|
||||
tz = pytz.timezone(tzname)
|
||||
utc = pytz.utc
|
||||
|
||||
def _parse(time_str: str | None, label: str) -> datetime.datetime | None:
|
||||
if not time_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
dt = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M").replace(second=0)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid {label} time format: {e}")
|
||||
|
||||
try:
|
||||
return tz.localize(dt, is_dst=None).astimezone(utc)
|
||||
except pytz.AmbiguousTimeError:
|
||||
return tz.localize(dt, is_dst=False).astimezone(utc)
|
||||
except pytz.NonExistentTimeError:
|
||||
dt += datetime.timedelta(hours=1)
|
||||
return tz.localize(dt, is_dst=None).astimezone(utc)
|
||||
|
||||
start_dt = _parse(start, "start")
|
||||
end_dt = _parse(end, "end")
|
||||
|
||||
# Range validation
|
||||
if start_dt and end_dt and start_dt > end_dt:
|
||||
raise ValueError("start must be earlier than or equal to end")
|
||||
|
||||
return start_dt, end_dt
|
||||
|
||||
@@ -177,6 +177,15 @@ def timezone(timezone_string):
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
|
||||
if dify_config.DB_TYPE == "postgresql":
|
||||
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
|
||||
elif dify_config.DB_TYPE == "mysql":
|
||||
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")
|
||||
|
||||
|
||||
def generate_string(n):
|
||||
letters_digits = string.ascii_letters + string.digits
|
||||
result = ""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user