mirror of
https://github.com/langgenius/dify.git
synced 2026-04-11 08:09:24 +08:00
Compare commits
63 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4678845dd | ||
|
|
174ebb51db | ||
|
|
626c78a690 | ||
|
|
9eaae770a6 | ||
|
|
ca60610306 | ||
|
|
082f8b17ab | ||
|
|
cf93d8d6e2 | ||
|
|
aae2fb8a30 | ||
|
|
23e52f14e3 | ||
|
|
c5b68fb273 | ||
|
|
6f17c9b2fe | ||
|
|
c98311b325 | ||
|
|
d44d4bd6fd | ||
|
|
2adaceab82 | ||
|
|
d979955c8a | ||
|
|
eae670ea4a | ||
|
|
b5825142d1 | ||
|
|
741e9303d4 | ||
|
|
538e3fc256 | ||
|
|
ba3dc8cae0 | ||
|
|
ae7c0380dc | ||
|
|
23e3413655 | ||
|
|
4fdb37771a | ||
|
|
94b54b7ca9 | ||
|
|
f9412f5fdb | ||
|
|
1d6829f400 | ||
|
|
f8bae897e5 | ||
|
|
dd1172b57e | ||
|
|
67d326a558 | ||
|
|
fe747040bc | ||
|
|
7d6c925cbc | ||
|
|
f488d06b20 | ||
|
|
c00a19ced3 | ||
|
|
e9810a6df2 | ||
|
|
cae15013e0 | ||
|
|
52c84da051 | ||
|
|
026f0bfce9 | ||
|
|
d19181fb29 | ||
|
|
2f9de2229f | ||
|
|
34f55739e0 | ||
|
|
668b059c07 | ||
|
|
753e5f1500 | ||
|
|
a6af8e5d8f | ||
|
|
3e1d5ac51b | ||
|
|
b0091452ca | ||
|
|
eff115267f | ||
|
|
07cde4f8fe | ||
|
|
9f28a48a92 | ||
|
|
0d3cd3b16a | ||
|
|
3dc82fb044 | ||
|
|
cb6e73347e | ||
|
|
ecd6cbaee6 | ||
|
|
d54e942264 | ||
|
|
28ba721455 | ||
|
|
784dd7848e | ||
|
|
e2a5f8ba1a | ||
|
|
8e11200306 | ||
|
|
7599f79a17 | ||
|
|
510389909c | ||
|
|
2c6e00174b | ||
|
|
24f3456990 | ||
|
|
20514ff288 | ||
|
|
381d255290 |
@@ -19,7 +19,7 @@ def check_file_for_chinese_comments(file_path):
|
||||
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py']
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
|
||||
36
LICENSE
36
LICENSE
@@ -1,26 +1,26 @@
|
||||
# Dify Open Source License
|
||||
|
||||
The Dify project uses a combination of the Apache License 2.0, MIT License, and an additional agreement to protect against direct competition with Dify Cloud services.
|
||||
The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
|
||||
|
||||
As a contributor, you should agree that your contributed code:
|
||||
a. Might be subject to a more permissive open source license in the future.
|
||||
1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
|
||||
|
||||
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
|
||||
b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
|
||||
|
||||
Please contact business@dify.ai by email to inquire about licensing matters.
|
||||
|
||||
2. As a contributor, you should agree that your contributed code:
|
||||
|
||||
a. The producer can adjust the open-source agreement to be more strict or relaxed.
|
||||
b. Can be used for commercial purposes, such as Dify's cloud business.
|
||||
|
||||
The following components are open source under the MIT license, allowing you to build and develop applications based on them:
|
||||
- WebApp elements, e.g., web/app/components/share
|
||||
- Derived WebApp Template projects
|
||||
|
||||
The remaining parts of the project are open source under the Apache License 2.0.
|
||||
|
||||
With the Apache License 2.0, MIT License, and this supplementary agreement, anyone can freely use, modify, and distribute Dify, provided that:
|
||||
|
||||
- If you use Dify solely as a backend service for other applications, no authorization is needed for commercial or closed source purposes.
|
||||
- If you wish to use Dify for commercial and closed source SaaS services similar to Dify Cloud, please contact us for authorization.
|
||||
Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
|
||||
|
||||
The interactive design of this product is protected by appearance patent.
|
||||
|
||||
© 2023 LangGenius, Inc.
|
||||
|
||||
|
||||
----------
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -34,13 +34,3 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
----------
|
||||
The MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
12
README.md
12
README.md
@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
|
||||
Visual data analysis, log review, and annotation for applications
|
||||
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo(ChatGPT)
|
||||
- GPT-4
|
||||
* **OpenAI** :GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
|
||||
|
||||
* **Azure OpenAI**
|
||||
|
||||
* **Antropic**:Claude2、Claude-instant
|
||||
> We've got 1000 free trial credits available for all cloud service users to try out the Claude model.Visit [Dify.ai](https://dify.ai) and
|
||||
try it now.
|
||||
|
||||
* **hugging face Hub**:Coming soon.
|
||||
|
||||
## Use Cloud Services
|
||||
|
||||
|
||||
13
README_CN.md
13
README_CN.md
@@ -17,11 +17,16 @@
|
||||
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
|
||||
- 可视化的对应用进行数据分析,查阅日志或进行标注
|
||||
|
||||
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前已支持:
|
||||
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo(ChatGPT)
|
||||
- GPT-4
|
||||
* **OpenAI**:GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
|
||||
|
||||
* **Azure OpenAI Service**
|
||||
* **Anthropic**:Claude2、Claude-instant
|
||||
|
||||
> 我们为所有注册云端版的用户免费提供了 1000 次 Claude 模型的消息调用额度,登录 [dify.ai](https://cloud.dify.ai) 即可使用。
|
||||
|
||||
* **Hugging Face Hub**(即将推出)
|
||||
|
||||
## 使用云服务
|
||||
|
||||
|
||||
@@ -27,4 +27,4 @@ RUN chmod +x /entrypoint.sh
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA ${COMMIT_SHA}
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
|
||||
@@ -2,6 +2,8 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
@@ -20,14 +22,14 @@ from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from models import model, account, dataset, web, task, source
|
||||
from models import model, account, dataset, web, task, source, tool
|
||||
from events import event_handlers
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
import core
|
||||
from config import Config, CloudEditionConfig
|
||||
from commands import register_commands
|
||||
from models.account import TenantAccountJoin
|
||||
from models.account import TenantAccountJoin, AccountStatus
|
||||
from models.model import Account, EndUser, App
|
||||
|
||||
import warnings
|
||||
@@ -101,6 +103,9 @@ def load_user(user_id):
|
||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if account:
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise Forbidden('Account is banned or closed.')
|
||||
|
||||
workspace_id = session.get('workspace_id')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
import click
|
||||
from flask import current_app
|
||||
@@ -13,12 +14,13 @@ from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
|
||||
from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider
|
||||
from models.provider import Provider, ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -171,7 +173,7 @@ def recreate_all_dataset_indexes():
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
@@ -187,15 +189,103 @@ def recreate_all_dataset_indexes():
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
click.echo(
|
||||
click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
|
||||
def clean_unused_dataset_indexes():
|
||||
click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
|
||||
clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
|
||||
start_at = time.perf_counter()
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
dataset_query = db.session.query(DatasetQuery).filter(
|
||||
DatasetQuery.created_at > thirty_days_ago,
|
||||
DatasetQuery.dataset_id == dataset.id
|
||||
).all()
|
||||
if not dataset_query or len(dataset_query) == 0:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.indexing_status == 'completed',
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.updated_at > thirty_days_ago
|
||||
).all()
|
||||
if not documents or len(documents) == 0:
|
||||
try:
|
||||
# remove index
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete()
|
||||
kw_index.delete()
|
||||
# update document
|
||||
update_params = {
|
||||
Document.enabled: False
|
||||
}
|
||||
|
||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.commit()
|
||||
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
end_at = time.perf_counter()
|
||||
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
|
||||
|
||||
|
||||
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
|
||||
def sync_anthropic_hosted_providers():
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for tenant in tenants:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
click.echo(click.style(
|
||||
'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(generate_invitation_codes)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(recreate_all_dataset_indexes)
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
app.cli.add_command(clean_unused_dataset_indexes)
|
||||
|
||||
@@ -50,7 +50,11 @@ DEFAULTS = {
|
||||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'DEFAULT_LLM_PROVIDER': 'openai'
|
||||
'DEFAULT_LLM_PROVIDER': 'openai',
|
||||
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
|
||||
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30
|
||||
}
|
||||
|
||||
|
||||
@@ -86,7 +90,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.8"
|
||||
self.CURRENT_VERSION = "0.3.12"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -191,6 +195,10 @@ class Config:
|
||||
|
||||
# hosted provider credentials
|
||||
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
|
||||
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
|
||||
|
||||
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
|
||||
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
@@ -207,6 +215,9 @@ class Config:
|
||||
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
|
||||
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
|
||||
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@@ -18,7 +18,10 @@ from .auth import login, oauth, data_source_oauth, activate
|
||||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import workspace, members, providers, account
|
||||
from .workspace import workspace, members, model_providers, account, tool_providers
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
|
||||
|
||||
# Import universal chat controllers
|
||||
from .universal_chat import chat, conversation, message, parameter, audio
|
||||
|
||||
@@ -24,6 +24,7 @@ model_config_fields = {
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
'pre_prompt': fields.String,
|
||||
@@ -96,7 +97,8 @@ class AppListApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == False).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False)
|
||||
@@ -147,6 +149,7 @@ class AppListApi(Resource):
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
@@ -438,6 +441,7 @@ class AppCopy(Resource):
|
||||
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
|
||||
speech_to_text=app_config.speech_to_text,
|
||||
more_like_this=app_config.more_like_this,
|
||||
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
|
||||
model=app_config.model,
|
||||
user_input_form=app_config.user_input_form,
|
||||
pre_prompt=app_config.pre_prompt,
|
||||
|
||||
@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -95,6 +95,7 @@ class CompletionConversationApi(Resource):
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(attribute='end_user.session_id'),
|
||||
'from_account_id': fields.String,
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
@@ -135,6 +136,8 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion')
|
||||
|
||||
query = query.options(joinedload(Conversation.end_user))
|
||||
|
||||
if args['keyword']:
|
||||
query = query.join(
|
||||
Message, Message.conversation_id == Conversation.id
|
||||
@@ -160,7 +163,7 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
@@ -246,6 +249,7 @@ class ChatConversationApi(Resource):
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(attribute='end_user.session_id'),
|
||||
'from_account_id': fields.String,
|
||||
'summary': fields.String(attribute='summary_or_query'),
|
||||
'read_at': TimestampField,
|
||||
@@ -288,6 +292,8 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat')
|
||||
|
||||
query = query.options(joinedload(Conversation.end_user))
|
||||
|
||||
if args['keyword']:
|
||||
query = query.join(
|
||||
Message, Message.conversation_id == Conversation.id
|
||||
@@ -316,7 +322,7 @@ class ChatConversationApi(Resource):
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
|
||||
error_code = 'provider_quota_exceeded'
|
||||
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
|
||||
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
code = 400
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
|
||||
args['audiences'],
|
||||
args['hoping_to_solve']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
|
||||
@@ -3,7 +3,6 @@ from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
@@ -221,6 +220,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
@@ -235,12 +235,12 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
|
||||
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
|
||||
args['process_rule'])
|
||||
args['process_rule'], args['doc_form'])
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response, 200
|
||||
|
||||
@@ -60,6 +60,7 @@ document_fields = {
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'doc_form': fields.String,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
@@ -86,6 +87,7 @@ document_with_segments_fields = {
|
||||
'total_segments': fields.Integer
|
||||
}
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@@ -269,6 +271,7 @@ class DatasetDocumentListApi(Resource):
|
||||
parser.add_argument('process_rule', type=dict, required=False, location='json')
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
@@ -279,8 +282,8 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -313,6 +316,7 @@ class DatasetInitApi(Resource):
|
||||
nullable=False, location='json')
|
||||
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# validate args
|
||||
@@ -324,8 +328,8 @@ class DatasetInitApi(Resource):
|
||||
document_data=args,
|
||||
account=current_user
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -488,6 +492,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
@@ -583,7 +589,8 @@ class DocumentDetailApi(DocumentResource):
|
||||
'segment_count': document.segment_count,
|
||||
'average_segment_length': document.average_segment_length,
|
||||
'hit_count': document.hit_count,
|
||||
'display_status': document.display_status
|
||||
'display_status': document.display_status,
|
||||
'doc_form': document.doc_form
|
||||
}
|
||||
else:
|
||||
process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@@ -614,7 +621,8 @@ class DocumentDetailApi(DocumentResource):
|
||||
'segment_count': document.segment_count,
|
||||
'average_segment_length': document.average_segment_length,
|
||||
'hit_count': document.hit_count,
|
||||
'display_status': document.display_status
|
||||
'display_status': document.display_status,
|
||||
'doc_form': document.doc_form
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.add_segment_to_index_task import add_segment_to_index_task
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
|
||||
|
||||
segment_fields = {
|
||||
@@ -24,6 +24,7 @@ segment_fields = {
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form,
|
||||
'has_more': has_more,
|
||||
'limit': limit,
|
||||
'total': total
|
||||
@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
# Set cache to prevent indexing the same segment multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
add_segment_to_index_task.delay(segment.id)
|
||||
enable_segment_to_index_task.delay(segment.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
elif action == "disable":
|
||||
@@ -202,7 +204,91 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
raise InvalidActionError()
|
||||
|
||||
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetDocumentSegmentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
api.add_resource(DatasetDocumentSegmentApi,
|
||||
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
|
||||
api.add_resource(DatasetDocumentSegmentAddApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
|
||||
@@ -28,6 +28,7 @@ segment_fields = {
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
@@ -95,8 +96,8 @@ class HitTestingApi(Resource):
|
||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -65,7 +65,10 @@ class ConversationApi(InstalledAppResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import InstalledApp
|
||||
|
||||
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -27,16 +31,17 @@ class AppParameterApi(InstalledAppResource):
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, installed_app):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
66
api/controllers/console/universal_chat/audio.py
Normal file
66
api/controllers/console/universal_chat/audio.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
|
||||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from models.model import AppModelConfig
|
||||
|
||||
|
||||
class UniversalChatAudioApi(UniversalChatResource):
|
||||
def post(self, universal_app):
|
||||
app_model = universal_app
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
|
||||
142
api/controllers/console/universal_chat/chat.py
Normal file
142
api/controllers/console/universal_chat/chat.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.constant import llm_constant
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
||||
class UniversalChatApi(UniversalChatResource):
|
||||
def post(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('model', type=str, required=True, location='json')
|
||||
parser.add_argument('tools', type=list, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
# update app model config
|
||||
args['model_config'] = app_model_config.to_dict()
|
||||
args['model_config']['model']['name'] = args['model']
|
||||
|
||||
if not llm_constant.models[args['model']]:
|
||||
raise ValueError("Model not exists.")
|
||||
|
||||
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
|
||||
args['model_config']['agent_mode']['tools'] = args['tools']
|
||||
|
||||
if not args['model_config']['agent_mode']['tools']:
|
||||
args['model_config']['agent_mode']['tools'] = [
|
||||
{
|
||||
"current_datetime": {
|
||||
"enabled": True
|
||||
}
|
||||
}
|
||||
]
|
||||
else:
|
||||
args['model_config']['agent_mode']['tools'].append({
|
||||
"current_datetime": {
|
||||
"enabled": True
|
||||
}
|
||||
})
|
||||
|
||||
args['inputs'] = {}
|
||||
|
||||
del args['model']
|
||||
del args['tools']
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
from_source='console',
|
||||
streaming=True,
|
||||
is_model_config_override=True,
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class UniversalChatStopApi(UniversalChatResource):
|
||||
def post(self, universal_app, task_id):
|
||||
PubHandler.stop(current_user, task_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
try:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
|
||||
api.add_resource(UniversalChatApi, '/universal-chat/messages')
|
||||
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
|
||||
118
api/controllers/console/universal_chat/conversation.py
Normal file
118
api/controllers/console/universal_chat/conversation.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, reqparse, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Raw,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class UniversalChatConversationListApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if 'pinned' in args and args['pinned'] is not None:
|
||||
pinned = True if args['pinned'] == 'true' else False
|
||||
|
||||
try:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
last_id=args['last_id'],
|
||||
limit=args['limit'],
|
||||
pinned=pinned
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatConversationApi(UniversalChatResource):
|
||||
def delete(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
def post(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatConversationPinApi(UniversalChatResource):
|
||||
|
||||
def patch(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class UniversalChatConversationUnPinApi(UniversalChatResource):
|
||||
def patch(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
|
||||
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
|
||||
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
|
||||
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
|
||||
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
|
||||
127
api/controllers/console/universal_chat/message.py
Normal file
127
api/controllers/console/universal_chat/message.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, fields, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound, InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class UniversalChatMessageListApi(UniversalChatResource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
agent_thought_fields = {
|
||||
'id': fields.String,
|
||||
'chain_id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'created_at': TimestampField,
|
||||
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(app_model, current_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.message.FirstMessageNotExistsError:
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatMessageFeedbackApi(UniversalChatResource):
|
||||
def post(self, universal_app, message_id):
|
||||
app_model = universal_app
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
|
||||
def get(self, universal_app, message_id):
|
||||
app_model = universal_app
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {'data': questions}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
|
||||
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
|
||||
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
|
||||
36
api/controllers/console/universal_chat/parameter.py
Normal file
36
api/controllers/console/universal_chat/parameter.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, fields
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
class UniversalChatParameterApi(UniversalChatResource):
|
||||
"""Resource for app variables."""
|
||||
parameters_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, universal_app: App):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = universal_app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
|
||||
84
api/controllers/console/universal_chat/wraps.py
Normal file
84
api/controllers/console/universal_chat/wraps.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
def universal_chat_app_required(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# get universal chat app
|
||||
universal_app = db.session.query(App).filter(
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == True
|
||||
).first()
|
||||
|
||||
if universal_app is None:
|
||||
# create universal app if not exists
|
||||
universal_app = App(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name='Universal Chat',
|
||||
mode='chat',
|
||||
is_universal=True,
|
||||
icon='',
|
||||
icon_background='',
|
||||
api_rpm=0,
|
||||
api_rph=0,
|
||||
enable_site=False,
|
||||
enable_api=False,
|
||||
status='normal'
|
||||
)
|
||||
|
||||
db.session.add(universal_app)
|
||||
db.session.flush()
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
provider="",
|
||||
model_id="",
|
||||
configs={},
|
||||
opening_statement='',
|
||||
suggested_questions=json.dumps([]),
|
||||
suggested_questions_after_answer=json.dumps({'enabled': True}),
|
||||
speech_to_text=json.dumps({'enabled': True}),
|
||||
more_like_this=None,
|
||||
sensitive_word_avoidance=None,
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-16k",
|
||||
"completion_params": {
|
||||
"max_tokens": 800,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}),
|
||||
user_input_form=json.dumps([]),
|
||||
pre_prompt='',
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
|
||||
)
|
||||
|
||||
app_model_config.app_id = universal_app.id
|
||||
db.session.add(app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
universal_app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
return view(universal_app, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
class UniversalChatResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, abort
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
|
||||
ProviderService.init_supported_provider(current_user.current_tenant)
|
||||
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
|
||||
|
||||
provider_list = [
|
||||
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
|
||||
'quota_used': p.quota_used
|
||||
} if p.provider_type == ProviderType.SYSTEM.value else {}),
|
||||
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
|
||||
ProviderName(p.provider_name))
|
||||
ProviderName(p.provider_name), only_custom=True)
|
||||
if p.provider_type == ProviderType.CUSTOM.value else None
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
|
||||
is_valid=token_is_valid)
|
||||
db.session.add(provider_model)
|
||||
|
||||
if provider_model.is_valid:
|
||||
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
|
||||
other_providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
|
||||
Provider.provider_name != provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
|
||||
|
||||
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: remove this when the provider is supported
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
|
||||
|
||||
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
|
||||
provider_model.is_valid = args['is_enabled']
|
||||
db.session.commit()
|
||||
elif not provider_model:
|
||||
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
|
||||
else:
|
||||
quota_limit = 0
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
provider,
|
||||
quota_limit,
|
||||
args['is_enabled']
|
||||
)
|
||||
else:
|
||||
abort(403)
|
||||
|
||||
136
api/controllers/console/workspace/tool_providers.py
Normal file
136
api/controllers/console/workspace/tool_providers.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.tool.provider.errors import ToolValidateFailedError
|
||||
from core.tool.provider.tool_provider_service import ToolProviderService
|
||||
from extensions.ext_database import db
|
||||
from models.tool import ToolProvider, ToolProviderName
|
||||
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_credential_dict = {}
|
||||
for tool_name in ToolProviderName:
|
||||
tool_credential_dict[tool_name.value] = {
|
||||
'tool_name': tool_name.value,
|
||||
'is_enabled': False,
|
||||
'credentials': None
|
||||
}
|
||||
|
||||
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
for p in tool_providers:
|
||||
if p.is_enabled:
|
||||
tool_credential_dict[p.tool_name] = {
|
||||
'tool_name': p.tool_name,
|
||||
'is_enabled': p.is_enabled,
|
||||
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
|
||||
}
|
||||
|
||||
return list(tool_credential_dict.values())
|
||||
|
||||
|
||||
class ToolProviderCredentialsApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ToolProviderName]:
|
||||
abort(404)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
|
||||
f'current_role is {current_user.current_tenant.current_role}')
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||
|
||||
try:
|
||||
tool_provider_service.credentials_validate(args['credentials'])
|
||||
except ToolValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
|
||||
|
||||
tenant = current_user.current_tenant
|
||||
|
||||
tool_provider_model = db.session.query(ToolProvider).filter(
|
||||
ToolProvider.tenant_id == tenant.id,
|
||||
ToolProvider.tool_name == provider,
|
||||
).first()
|
||||
|
||||
# Only allow updating token for CUSTOM provider type
|
||||
if tool_provider_model:
|
||||
tool_provider_model.encrypted_credentials = encrypted_credentials
|
||||
tool_provider_model.is_enabled = True
|
||||
else:
|
||||
tool_provider_model = ToolProvider(
|
||||
tenant_id=tenant.id,
|
||||
tool_name=provider,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
is_enabled=True
|
||||
)
|
||||
db.session.add(tool_provider_model)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
class ToolProviderCredentialsValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ToolProviderName]:
|
||||
abort(404)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||
|
||||
try:
|
||||
tool_provider_service.credentials_validate(args['credentials'])
|
||||
except ToolValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
||||
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
|
||||
api.add_resource(ToolProviderCredentialsValidateApi,
|
||||
'/workspaces/current/tool-providers/<provider>/credentials-validate')
|
||||
@@ -4,6 +4,10 @@ from flask_restful import fields, marshal_with
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppParameterApi(AppApiResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -28,15 +32,16 @@ class AppParameterApi(AppApiResource):
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
|
||||
dataset_process_rule=dataset.latest_process_rule,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
if doc_type and doc_metadata:
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -27,15 +31,16 @@ class AppParameterApi(WebApiResource):
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -62,7 +62,10 @@ class ConversationApi(WebApiResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -38,6 +38,8 @@ def decode_jwt_token():
|
||||
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if app_model.enable_site is False:
|
||||
raise Unauthorized('Site is disabled.')
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedAnthropicCredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedLLMCredentials(BaseModel):
|
||||
openai: Optional[HostedOpenAICredential] = None
|
||||
anthropic: Optional[HostedAnthropicCredential] = None
|
||||
|
||||
|
||||
hosted_llm_credentials = HostedLLMCredentials()
|
||||
@@ -26,3 +31,6 @@ def init_app(app: Flask):
|
||||
|
||||
if app.config.get("OPENAI_API_KEY"):
|
||||
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
|
||||
|
||||
if app.config.get("ANTHROPIC_API_KEY"):
|
||||
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
35
api/core/agent/agent/calc_token_mixin.py
Normal file
35
api/core/agent/agent/calc_token_mixin.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain import OpenAI
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from core.constant import llm_constant
|
||||
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
return llm.get_num_tokens_from_messages(messages)
|
||||
|
||||
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
:param llm:
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
completion_max_tokens = llm.max_tokens
|
||||
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
|
||||
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
class ExceededLLMTokensLimitError(Exception):
|
||||
pass
|
||||
83
api/core/agent/agent/multi_dataset_router_agent.py
Normal file
83
api/core/agent/agent/multi_dataset_router_agent.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(self.tools) == 0:
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
return super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
**kwargs,
|
||||
)
|
||||
112
api/core/agent/agent/openai_function_call.py
Normal file
112
api/core/agent/agent/openai_function_call.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
_format_intermediate_steps
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
try:
|
||||
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
132
api/core/agent/agent/openai_function_call_summarize_mixin.py
Normal file
132
api/core/agent/agent/openai_function_call_summarize_mixin.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_message_to_dict
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
|
||||
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
|
||||
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
model, encoding = llm._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
102
api/core/agent/agent/openai_multi_function_call.py
Normal file
102
api/core/agent/agent/openai_multi_function_call.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
|
||||
_parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseMultiActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
# get current time
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
29
api/core/agent/agent/output_parser/structured_chat.py
Normal file
29
api/core/agent/agent/output_parser/structured_chat.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
|
||||
logger
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class StructuredChatOutputParser(LCStructuredChatOutputParser):
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
try:
|
||||
action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
|
||||
if action_match is not None:
|
||||
response = json.loads(action_match.group(2).strip(), strict=False)
|
||||
if isinstance(response, list):
|
||||
# gpt turbo frequently ignores the directive to emit a single action
|
||||
logger.warning("Got multiple action responses: %s", response)
|
||||
response = response[0]
|
||||
if response["action"] == "Final Answer":
|
||||
return AgentFinish({"output": response["action_input"]}, text)
|
||||
else:
|
||||
return AgentAction(
|
||||
response["action"], response.get("action_input", {}), text
|
||||
)
|
||||
else:
|
||||
return AgentFinish({"output": text}, text)
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||
187
api/core/agent/agent/structured_chat.py
Normal file
187
api/core/agent/agent/structured_chat.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}}}
|
||||
```"""
|
||||
|
||||
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
Using the ReACT mode to determine whether an agent is needed is costly,
|
||||
so it's better to just use an Agent for reasoning, which is cheaper.
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
|
||||
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
|
||||
messages = []
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
|
||||
|
||||
self.moving_summary_index = len(intermediate_steps)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].pop()
|
||||
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
if 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
|
||||
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_memory_prompts = memory_prompts or []
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
*_memory_prompts,
|
||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,86 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
|
||||
|
||||
class AgentBuilder:
|
||||
@classmethod
|
||||
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
|
||||
dataset_tool_callback_handler: DatasetToolCallbackHandler,
|
||||
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
|
||||
llm = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=agent_loop_gather_callback_handler.model_name,
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
for tool in tools:
|
||||
tool.callbacks = [
|
||||
agent_loop_gather_callback_handler,
|
||||
dataset_tool_callback_handler,
|
||||
DifyStdOutCallbackHandler()
|
||||
]
|
||||
|
||||
prompt = cls.build_agent_prompt_template(
|
||||
tools=tools,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
agent_llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
|
||||
|
||||
agent_callback_manager = CallbackManager(
|
||||
[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
agent_chain = AgentExecutor.from_agent_and_tools(
|
||||
tools=tools,
|
||||
agent=agent,
|
||||
memory=memory,
|
||||
callbacks=agent_callback_manager,
|
||||
max_iterations=6,
|
||||
early_stopping_method="generate",
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
)
|
||||
|
||||
return agent_chain
|
||||
|
||||
@classmethod
|
||||
def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
|
||||
if memory:
|
||||
prompt = ConversationalAgent.create_prompt(
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
|
||||
if memory:
|
||||
agent = ConversationalAgent(
|
||||
llm_chain=agent_llm_chain
|
||||
)
|
||||
else:
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=agent_llm_chain
|
||||
)
|
||||
|
||||
return agent
|
||||
122
api/core/agent/agent_executor.py
Normal file
122
api/core/agent/agent_executor.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import Union, Optional
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class PlanningStrategy(str, enum.Enum):
|
||||
ROUTER = 'router'
|
||||
REACT = 'react'
|
||||
FUNCTION_CALL = 'function_call'
|
||||
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
llm: BaseLanguageModel
|
||||
tools: list[BaseTool]
|
||||
summary_llm: BaseLanguageModel
|
||||
dataset_llm: BaseLanguageModel
|
||||
memory: Optional[BaseChatMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
max_execution_time: Optional[float] = None
|
||||
early_stopping_method: str = "generate"
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AgentExecuteResult(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
output: Optional[str]
|
||||
configuration: AgentConfiguration
|
||||
|
||||
|
||||
class AgentExecutor:
|
||||
def __init__(self, configuration: AgentConfiguration):
|
||||
self.configuration = configuration
|
||||
self.agent = self._init_agent()
|
||||
|
||||
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
llm=self.configuration.dataset_llm,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
verbose=True
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
|
||||
|
||||
return agent
|
||||
|
||||
def should_use_agent(self, query: str) -> bool:
|
||||
return self.agent.should_use_agent(query)
|
||||
|
||||
def run(self, query: str) -> AgentExecuteResult:
|
||||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||
agent=self.agent,
|
||||
tools=self.configuration.tools,
|
||||
memory=self.configuration.memory,
|
||||
max_iterations=self.configuration.max_iterations,
|
||||
max_execution_time=self.configuration.max_execution_time,
|
||||
early_stopping_method=self.configuration.early_stopping_method,
|
||||
callbacks=self.configuration.callbacks
|
||||
)
|
||||
|
||||
try:
|
||||
output = agent_executor.run(query)
|
||||
except Exception:
|
||||
logging.exception("agent_executor run failed")
|
||||
output = None
|
||||
|
||||
return AgentExecuteResult(
|
||||
output=output,
|
||||
strategy=self.configuration.strategy,
|
||||
configuration=self.configuration
|
||||
)
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
self.current_chain = None
|
||||
|
||||
@property
|
||||
@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
def clear_agent_loops(self) -> None:
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
@@ -61,9 +65,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
# kwargs={}
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
if 'function_call' in completion_message.additional_kwargs:
|
||||
self._current_loop.completion \
|
||||
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
|
||||
else:
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
else:
|
||||
self._current_loop.completion = completion_generation.text
|
||||
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
@@ -71,6 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
logging.error(error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
@@ -89,15 +106,29 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
tool = action.tool
|
||||
tool_input = action.tool_input
|
||||
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
tool_input = json.dumps({"query": action.tool_input}
|
||||
if isinstance(action.tool_input, str) else action.tool_input)
|
||||
completion = None
|
||||
if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
|
||||
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
|
||||
thought = action.log.strip()
|
||||
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
|
||||
else:
|
||||
action_name_position = action.log.index("Action:") if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
|
||||
if self._current_loop and self._current_loop.status == 'llm_end':
|
||||
self._current_loop.status = 'agent_action'
|
||||
self._current_loop.thought = thought
|
||||
self._current_loop.tool_name = tool
|
||||
self._current_loop.tool_input = tool_input
|
||||
if completion is not None:
|
||||
self._current_loop.completion = completion
|
||||
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
@@ -120,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_name, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
@@ -132,6 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
logging.error(error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
@@ -141,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.completed = True
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
self._current_loop.thought = '[DONE]'
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
|
||||
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_name, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
elif not self._current_loop and self._agent_loops:
|
||||
self._agent_loops[-1].status = 'agent_finish'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
@@ -43,9 +44,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_name = serialized.get('name')
|
||||
dataset_id = tool_name[len("dataset-"):]
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
|
||||
# tool_name = serialized.get('name')
|
||||
input_dict = json.loads(input_str.replace("'", "\""))
|
||||
dataset_id = input_dict.get('dataset_id')
|
||||
query = input_dict.get('query')
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
|
||||
@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
|
||||
tool_output: str = None
|
||||
|
||||
prompt: str = None
|
||||
prompt_tokens: int = None
|
||||
prompt_tokens: int = 0
|
||||
completion: str = None
|
||||
completion_tokens: int = None
|
||||
completion_tokens: int = 0
|
||||
|
||||
latency: float = None
|
||||
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def __init__(self, llm: BaseLanguageModel,
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
self.llm = llm
|
||||
self.llm_message = LLMMessage()
|
||||
@@ -48,7 +46,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
@@ -69,9 +67,8 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
|
||||
@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
|
||||
llm_constant.agent_model_name,
|
||||
conversation_message_task
|
||||
)
|
||||
self.agent_callback = None
|
||||
|
||||
def clear_chain_results(self) -> None:
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.agent_loop_gather_callback_handler.current_chain = None
|
||||
if self.agent_callback:
|
||||
self.agent_callback.current_chain = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
|
||||
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
|
||||
if self.agent_callback:
|
||||
self.agent_callback.current_chain = self._current_chain_message
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.chain.tool_chain import ToolChain
|
||||
|
||||
|
||||
class ChainBuilder:
|
||||
@classmethod
|
||||
def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
|
||||
return ToolChain(
|
||||
tool=tool,
|
||||
input_key=kwargs.get('input_key', 'input'),
|
||||
output_key=kwargs.get('output_key', 'tool_output'),
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
|
||||
SensitiveWordAvoidanceChain]:
|
||||
sensitive_words = tool_config.get("words", "")
|
||||
if tool_config.get("enabled", False) \
|
||||
and sensitive_words:
|
||||
return SensitiveWordAvoidanceChain(
|
||||
sensitive_words=sensitive_words.split(","),
|
||||
canned_response=tool_config.get("canned_response", ''),
|
||||
output_key="sensitive_word_avoidance_output",
|
||||
callbacks=[DifyStdOutCallbackHandler()],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -1,111 +0,0 @@
|
||||
"""Base classes for LLM-powered router chains."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
|
||||
class Route(NamedTuple):
|
||||
destination: Optional[str]
|
||||
next_inputs: Dict[str, Any]
|
||||
|
||||
|
||||
class LLMRouterChain(Chain):
|
||||
"""A router chain that uses an LLM chain to perform routing."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLM chain used to perform routing"""
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt(cls, values: dict) -> dict:
|
||||
prompt = values["llm_chain"].prompt
|
||||
if prompt.output_parser is None:
|
||||
raise ValueError(
|
||||
"LLMRouterChain requires base llm_chain prompt to have an output"
|
||||
" parser that converts LLM text output to a dictionary with keys"
|
||||
" 'destination' and 'next_inputs'. Received a prompt with no output"
|
||||
" parser."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the LLM chain prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.llm_chain.input_keys
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
super()._validate_outputs(outputs)
|
||||
if not isinstance(outputs["next_inputs"], dict):
|
||||
raise ValueError
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
output = cast(
|
||||
Dict[str, Any],
|
||||
self.llm_chain.predict_and_parse(**inputs),
|
||||
)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
|
||||
) -> LLMRouterChain:
|
||||
"""Convenience constructor."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
return ["destination", "next_inputs"]
|
||||
|
||||
def route(self, inputs: Dict[str, Any]) -> Route:
|
||||
result = self(inputs)
|
||||
return Route(result["destination"], result["next_inputs"])
|
||||
|
||||
|
||||
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
||||
"""Parser for output of router chain int he multi-prompt chain."""
|
||||
|
||||
default_destination: str = "DEFAULT"
|
||||
next_inputs_type: Type = str
|
||||
next_inputs_inner_key: str = "input"
|
||||
|
||||
def parse(self, text: str) -> Dict[str, Any]:
|
||||
try:
|
||||
expected_keys = ["destination", "next_inputs"]
|
||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||
if not isinstance(parsed["destination"], str):
|
||||
raise ValueError("Expected 'destination' to be a string.")
|
||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||
raise ValueError(
|
||||
f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
||||
)
|
||||
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
||||
if (
|
||||
parsed["destination"].strip().lower()
|
||||
== self.default_destination.lower()
|
||||
):
|
||||
parsed["destination"] = None
|
||||
else:
|
||||
parsed["destination"] = parsed["destination"].strip()
|
||||
return parsed
|
||||
except Exception as e:
|
||||
raise OutputParserException(
|
||||
f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
|
||||
)
|
||||
@@ -1,110 +0,0 @@
|
||||
from typing import Optional, List, cast
|
||||
|
||||
from langchain.chains import SequentialChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.chain_builder import ChainBuilder
|
||||
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MainChainBuilder:
|
||||
@classmethod
|
||||
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
||||
rest_tokens: int,
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
first_input_key = "input"
|
||||
final_output_key = "output"
|
||||
|
||||
chains = []
|
||||
|
||||
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
|
||||
# agent mode
|
||||
tool_chains, chains_output_key = cls.get_agent_chains(
|
||||
tenant_id=tenant_id,
|
||||
agent_mode=agent_mode,
|
||||
rest_tokens=rest_tokens,
|
||||
memory=memory,
|
||||
conversation_message_task=conversation_message_task
|
||||
)
|
||||
chains += tool_chains
|
||||
|
||||
if chains_output_key:
|
||||
final_output_key = chains_output_key
|
||||
|
||||
if len(chains) == 0:
|
||||
return None
|
||||
|
||||
for chain in chains:
|
||||
chain = cast(Chain, chain)
|
||||
chain.callbacks.append(chain_callback_handler)
|
||||
|
||||
# build main chain
|
||||
overall_chain = SequentialChain(
|
||||
chains=chains,
|
||||
input_variables=[first_input_key],
|
||||
output_variables=[final_output_key],
|
||||
memory=memory, # only for use the memory prompt input key
|
||||
)
|
||||
|
||||
return overall_chain
|
||||
|
||||
@classmethod
|
||||
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
|
||||
rest_tokens: int,
|
||||
memory: Optional[BaseChatMemory],
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
# agent mode
|
||||
chains = []
|
||||
if agent_mode and agent_mode.get('enabled'):
|
||||
tools = agent_mode.get('tools', [])
|
||||
|
||||
pre_fixed_chains = []
|
||||
# agent_tools = []
|
||||
datasets = []
|
||||
for tool in tools:
|
||||
tool_type = list(tool.keys())[0]
|
||||
tool_config = list(tool.values())[0]
|
||||
if tool_type == 'sensitive-word-avoidance':
|
||||
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
|
||||
if chain:
|
||||
pre_fixed_chains.append(chain)
|
||||
elif tool_type == "dataset":
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == tool_config.get("id")
|
||||
).first()
|
||||
|
||||
if dataset:
|
||||
datasets.append(dataset)
|
||||
|
||||
# add pre-fixed chains
|
||||
chains += pre_fixed_chains
|
||||
|
||||
if len(datasets) > 0:
|
||||
# tool to chain
|
||||
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
|
||||
tenant_id=tenant_id,
|
||||
datasets=datasets,
|
||||
conversation_message_task=conversation_message_task,
|
||||
rest_tokens=rest_tokens,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
chains.append(multi_dataset_router_chain)
|
||||
|
||||
final_output_key = cls.get_chains_output_key(chains)
|
||||
|
||||
return chains, final_output_key
|
||||
|
||||
@classmethod
|
||||
def get_chains_output_key(cls, chains: List[Chain]):
|
||||
if len(chains) > 0:
|
||||
return chains[-1].output_keys[0]
|
||||
return None
|
||||
@@ -1,198 +0,0 @@
|
||||
import math
|
||||
import re
|
||||
from typing import Mapping, List, Dict, Any, Optional
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import Extra
|
||||
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.tool.dataset_index_tool import DatasetTool
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
||||
DEFAULT_K = 2
|
||||
CONTEXT_TOKENS_PERCENT = 0.3
|
||||
MULTI_PROMPT_ROUTER_TEMPLATE = """
|
||||
Given a raw text input to a language model select the model prompt best suited for \
|
||||
the input. You will be given the names of the available prompts and a description of \
|
||||
what the prompt is best suited for. You may also revise the original input if you \
|
||||
think that revising it will ultimately lead to a better response from the language \
|
||||
model.
|
||||
|
||||
<< FORMATTING >>
|
||||
Return a markdown code snippet with a JSON object formatted to look like, \
|
||||
no any other string out of markdown code snippet:
|
||||
```json
|
||||
{{{{
|
||||
"destination": string \\ name of the prompt to use or "DEFAULT"
|
||||
"next_inputs": string \\ a potentially modified version of the original input
|
||||
}}}}
|
||||
```
|
||||
|
||||
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
|
||||
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
|
||||
REMEMBER: "next_inputs" can just be the original input if you don't think any \
|
||||
modifications are needed.
|
||||
|
||||
<< CANDIDATE PROMPTS >>
|
||||
{destinations}
|
||||
|
||||
<< INPUT >>
|
||||
{{input}}
|
||||
|
||||
<< OUTPUT >>
|
||||
"""
|
||||
|
||||
|
||||
class MultiDatasetRouterChain(Chain):
|
||||
"""Use a single chain to route an input to one of multiple candidate chains."""
|
||||
|
||||
router_chain: LLMRouterChain
|
||||
"""Chain for deciding a destination chain and the input to it."""
|
||||
dataset_tools: Mapping[str, DatasetTool]
|
||||
"""Map of name to candidate chains that inputs can be routed to."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the router chain prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.router_chain.input_keys
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
return ["text"]
|
||||
|
||||
@classmethod
|
||||
def from_datasets(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
datasets: List[Dataset],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Convenience constructor for instantiating from destination prompts."""
|
||||
llm = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
|
||||
else ('useful for when you want to answer queries about the ' + d.name))
|
||||
for d in datasets]
|
||||
destinations_str = "\n".join(destinations)
|
||||
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
|
||||
destinations=destinations_str
|
||||
)
|
||||
|
||||
router_prompt = PromptTemplate(
|
||||
template=router_template,
|
||||
input_variables=["input"],
|
||||
output_parser=RouterOutputParser(),
|
||||
)
|
||||
|
||||
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
|
||||
dataset_tools = {}
|
||||
for dataset in datasets:
|
||||
# fulfill description when it is empty
|
||||
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
|
||||
continue
|
||||
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||
|
||||
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
|
||||
if k == 0:
|
||||
continue
|
||||
|
||||
dataset_tool = DatasetTool(
|
||||
name=f"dataset-{dataset.id}",
|
||||
description=description,
|
||||
k=k,
|
||||
dataset=dataset,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
dataset_tools[str(dataset.id)] = dataset_tool
|
||||
|
||||
return cls(
|
||||
router_chain=router_chain,
|
||||
dataset_tools=dataset_tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
|
||||
processing_rule = dataset.latest_process_rule
|
||||
if not processing_rule:
|
||||
return DEFAULT_K
|
||||
|
||||
if processing_rule.mode == "custom":
|
||||
rules = processing_rule.rules_dict
|
||||
if not rules:
|
||||
return DEFAULT_K
|
||||
|
||||
segmentation = rules["segmentation"]
|
||||
segment_max_tokens = segmentation["max_tokens"]
|
||||
else:
|
||||
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
|
||||
|
||||
# when rest_tokens is less than default context tokens
|
||||
if rest_tokens < segment_max_tokens * DEFAULT_K:
|
||||
return rest_tokens // segment_max_tokens
|
||||
|
||||
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
|
||||
|
||||
# when context_limit_tokens is less than default context tokens, use default_k
|
||||
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
|
||||
return DEFAULT_K
|
||||
|
||||
# Expand the k value when there's still some room left in the 30% rest tokens space
|
||||
return context_limit_tokens // segment_max_tokens
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if len(self.dataset_tools) == 0:
|
||||
return {"text": ''}
|
||||
elif len(self.dataset_tools) == 1:
|
||||
return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
|
||||
|
||||
route = self.router_chain.route(inputs)
|
||||
|
||||
destination = ''
|
||||
if route.destination:
|
||||
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
|
||||
match = re.search(pattern, route.destination, re.IGNORECASE)
|
||||
if match:
|
||||
destination = match.group()
|
||||
|
||||
if not destination:
|
||||
return {"text": ''}
|
||||
elif destination in self.dataset_tools:
|
||||
return {"text": self.dataset_tools[destination].run(
|
||||
route.next_inputs['input']
|
||||
)}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Received invalid destination chain name '{destination}'"
|
||||
)
|
||||
@@ -1,51 +0,0 @@
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ToolChain(Chain):
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
tool: BaseTool
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "tool_chain"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
input = inputs[self.input_key]
|
||||
output = self.tool.run(input, self.verbose)
|
||||
return {self.output_key: output}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
input = inputs[self.input_key]
|
||||
output = await self.tool.arun(input, self.verbose)
|
||||
return {self.output_key: output}
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Union, Tuple
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
@@ -8,30 +9,31 @@ from langchain.llms import BaseLLM
|
||||
from langchain.schema import BaseMessage, HumanMessage
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.constant import llm_constant
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
|
||||
DifyStdOutCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.error import LLMBadRequestError
|
||||
from core.llm.fake import FakeLLM
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.chain.main_chain_builder import MainChainBuilder
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBStringBufferSharedMemory
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
|
||||
|
||||
|
||||
class Completion:
|
||||
@classmethod
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
|
||||
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
@@ -69,18 +71,33 @@ class Completion:
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
# build main chain include agent
|
||||
main_chain = MainChainBuilder.to_langchain_components(
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
|
||||
# init orchestrator rule parser
|
||||
orchestrator_rule_parser = OrchestratorRuleParser(
|
||||
tenant_id=app.tenant_id,
|
||||
agent_mode=app_model_config.agent_mode_dict,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
|
||||
conversation_message_task=conversation_message_task
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
chain_output = ''
|
||||
if main_chain:
|
||||
chain_output = main_chain.run(query)
|
||||
# parse sensitive_word_avoidance_chain
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback
|
||||
)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
|
||||
# run the final llm
|
||||
try:
|
||||
@@ -90,7 +107,7 @@ class Completion:
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=chain_output,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
streaming=streaming
|
||||
@@ -105,9 +122,20 @@ class Completion:
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
chain_output: str,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
|
||||
final_llm = FakeLLM(response=agent_execute_result.output,
|
||||
origin_llm=agent_execute_result.configuration.llm,
|
||||
streaming=streaming)
|
||||
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
|
||||
response = final_llm.generate([[HumanMessage(content=query)]])
|
||||
return response
|
||||
|
||||
final_llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict,
|
||||
@@ -118,10 +146,11 @@ class Completion:
|
||||
prompt, stop_words = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=chain_output,
|
||||
agent_execute_result=agent_execute_result,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
@@ -129,6 +158,7 @@ class Completion:
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode=mode
|
||||
)
|
||||
@@ -138,41 +168,31 @@ class Completion:
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
|
||||
chain_output: Optional[str],
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
# disable template string in query
|
||||
# query_params = JinjaPromptTemplate.from_template(template=query).input_variables
|
||||
# if query_params:
|
||||
# for query_param in query_params:
|
||||
# if query_param not in inputs:
|
||||
# inputs[query_param] = '{{' + query_param + '}}'
|
||||
|
||||
if mode == 'completion':
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following CONTEXT as your learned knowledge:
|
||||
[CONTEXT]
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
- If you don't know when you are not sure, ask for clarification.
|
||||
Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
""" if chain_output else "")
|
||||
""" if agent_execute_result else "")
|
||||
+ (pre_prompt + "\n" if pre_prompt else "")
|
||||
+ "{{query}}\n"
|
||||
)
|
||||
|
||||
if chain_output:
|
||||
inputs['context'] = chain_output
|
||||
# context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
|
||||
# if context_params:
|
||||
# for context_param in context_params:
|
||||
# if context_param not in inputs:
|
||||
# inputs[context_param] = '{{' + context_param + '}}'
|
||||
if agent_execute_result:
|
||||
inputs['context'] = agent_execute_result.output
|
||||
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
prompt_content = prompt_template.format(
|
||||
@@ -202,12 +222,13 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt_inputs:
|
||||
human_inputs.update(pre_prompt_inputs)
|
||||
|
||||
if chain_output:
|
||||
human_inputs['context'] = chain_output
|
||||
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
|
||||
[CONTEXT]
|
||||
if agent_execute_result:
|
||||
human_inputs['context'] = agent_execute_result.output
|
||||
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
@@ -219,7 +240,7 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt:
|
||||
human_message_prompt += pre_prompt
|
||||
|
||||
query_prompt = "\nHuman: {{query}}\nAI: "
|
||||
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
|
||||
|
||||
if memory:
|
||||
# append chat histories
|
||||
@@ -228,20 +249,17 @@ And answer according to the language of the user's question.
|
||||
inputs=human_inputs
|
||||
)
|
||||
|
||||
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
|
||||
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
|
||||
- memory.llm.max_tokens - curr_message_tokens
|
||||
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
|
||||
model_name = model['name']
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = llm_constant.max_context_token_length[model_name] \
|
||||
- max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
|
||||
# disable template string in query
|
||||
# histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
|
||||
# if histories_params:
|
||||
# for histories_param in histories_params:
|
||||
# if histories_param not in human_inputs:
|
||||
# human_inputs[histories_param] = '{{' + histories_param + '}}'
|
||||
|
||||
human_message_prompt += "\n\n" + histories
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
"inside <histories></histories> XML tags.\n\n<histories>\n"
|
||||
human_message_prompt += histories + "\n</histories>"
|
||||
|
||||
human_message_prompt += query_prompt
|
||||
|
||||
@@ -253,10 +271,13 @@ And answer according to the language of the user's question.
|
||||
|
||||
messages.append(human_message)
|
||||
|
||||
return messages, ['\nHuman:']
|
||||
for message in messages:
|
||||
message.content = re.sub(r'<\|.*?\|>', '', message.content)
|
||||
|
||||
return messages, ['\nHuman:', '</histories>']
|
||||
|
||||
@classmethod
|
||||
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def get_llm_callbacks(cls, llm: BaseLanguageModel,
|
||||
streaming: bool,
|
||||
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
|
||||
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
|
||||
@@ -267,8 +288,7 @@ And answer according to the language of the user's question.
|
||||
|
||||
@classmethod
|
||||
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
|
||||
max_token_limit: int) -> \
|
||||
str:
|
||||
max_token_limit: int) -> str:
|
||||
"""Get memory messages."""
|
||||
memory.max_token_limit = max_token_limit
|
||||
memory_key = memory.memory_variables[0]
|
||||
@@ -307,17 +327,19 @@ And answer according to the language of the user's question.
|
||||
model=app_model_config.model_dict
|
||||
)
|
||||
|
||||
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
max_tokens = llm.max_tokens
|
||||
model_name = app_model_config.model_dict.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
|
||||
|
||||
# get prompt without memory and context
|
||||
prompt, _ = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=None,
|
||||
agent_execute_result=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
@@ -332,16 +354,17 @@ And answer according to the language of the user's question.
|
||||
return rest_tokens
|
||||
|
||||
@classmethod
|
||||
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
|
||||
prompt: Union[str, List[BaseMessage]], mode: str):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
|
||||
max_tokens = final_llm.max_tokens
|
||||
model_name = model.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
|
||||
if mode == 'completion' and isinstance(final_llm, BaseLLM):
|
||||
prompt_tokens = final_llm.get_num_tokens(prompt)
|
||||
else:
|
||||
prompt_tokens = final_llm.get_messages_tokens(prompt)
|
||||
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
|
||||
|
||||
if prompt_tokens + max_tokens > model_limited_tokens:
|
||||
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
|
||||
@@ -350,9 +373,10 @@ And answer according to the language of the user's question.
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
model=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -360,10 +384,12 @@ And answer according to the language of the user's question.
|
||||
original_prompt, _ = cls.get_main_llm_prompt(
|
||||
mode="completion",
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=pre_prompt,
|
||||
query=message.query,
|
||||
inputs=message.inputs,
|
||||
chain_output=None,
|
||||
agent_execute_result=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
@@ -390,6 +416,7 @@ And answer according to the language of the user's question.
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode='completion'
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from _decimal import Decimal
|
||||
|
||||
models = {
|
||||
'claude-instant-1': 'anthropic', # 100,000 tokens
|
||||
'claude-2': 'anthropic', # 100,000 tokens
|
||||
'gpt-4': 'openai', # 8,192 tokens
|
||||
'gpt-4-32k': 'openai', # 32,768 tokens
|
||||
'gpt-3.5-turbo': 'openai', # 4,096 tokens
|
||||
@@ -10,10 +12,13 @@ models = {
|
||||
'text-curie-001': 'openai', # 2,049 tokens
|
||||
'text-babbage-001': 'openai', # 2,049 tokens
|
||||
'text-ada-001': 'openai', # 2,049 tokens
|
||||
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
|
||||
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
|
||||
'whisper-1': 'openai'
|
||||
}
|
||||
|
||||
max_context_token_length = {
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
@@ -23,17 +28,21 @@ max_context_token_length = {
|
||||
'text-curie-001': 2049,
|
||||
'text-babbage-001': 2049,
|
||||
'text-ada-001': 2049,
|
||||
'text-embedding-ada-002': 8191
|
||||
'text-embedding-ada-002': 8191,
|
||||
}
|
||||
|
||||
models_by_mode = {
|
||||
'chat': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
@@ -52,6 +61,14 @@ models_by_mode = {
|
||||
model_currency = 'USD'
|
||||
|
||||
model_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': Decimal('0.00163'),
|
||||
'completion': Decimal('0.00551'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': Decimal('0.01102'),
|
||||
'completion': Decimal('0.03268'),
|
||||
},
|
||||
'gpt-4': {
|
||||
'prompt': Decimal('0.03'),
|
||||
'completion': Decimal('0.06'),
|
||||
|
||||
@@ -52,11 +52,11 @@ class ConversationMessageTask:
|
||||
message=self.message,
|
||||
conversation=self.conversation,
|
||||
chain_pub=False, # disabled currently
|
||||
agent_thought_pub=False # disabled currently
|
||||
agent_thought_pub=True
|
||||
)
|
||||
|
||||
def init(self):
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
|
||||
self.model_dict['provider'] = provider_name
|
||||
|
||||
override_model_configs = None
|
||||
@@ -69,6 +69,7 @@ class ConversationMessageTask:
|
||||
"suggested_questions": self.app_model_config.suggested_questions_list,
|
||||
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
|
||||
"more_like_this": self.app_model_config.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
|
||||
"user_input_form": self.app_model_config.user_input_form_list,
|
||||
}
|
||||
|
||||
@@ -89,7 +90,7 @@ class ConversationMessageTask:
|
||||
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
|
||||
system_instruction = system_message.content
|
||||
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
|
||||
system_instruction_tokens = llm.get_messages_tokens([system_message])
|
||||
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
|
||||
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
@@ -185,6 +186,7 @@ class ConversationMessageTask:
|
||||
if provider and provider.provider_type == ProviderType.SYSTEM.value:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.app.tenant_id,
|
||||
Provider.provider_name == provider.provider_name,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + 1})
|
||||
|
||||
@@ -206,7 +208,28 @@ class ConversationMessageTask:
|
||||
|
||||
self._pub_handler.pub_chain(message_chain)
|
||||
|
||||
def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
|
||||
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
|
||||
message_agent_thought = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=message_chain.id,
|
||||
position=agent_loop.position,
|
||||
thought=agent_loop.thought,
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
message=agent_loop.prompt,
|
||||
answer=agent_loop.completion,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.flush()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_thought)
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
|
||||
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
|
||||
@@ -221,34 +244,18 @@ class ConversationMessageTask:
|
||||
agent_answer_unit_price
|
||||
)
|
||||
|
||||
message_agent_loop = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=message_chain.id,
|
||||
position=agent_loop.position,
|
||||
thought=agent_loop.thought,
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
observation=agent_loop.tool_output,
|
||||
tool_process_data='', # currently not support
|
||||
message=agent_loop.prompt,
|
||||
message_token=loop_message_tokens,
|
||||
message_unit_price=agent_message_unit_price,
|
||||
answer=agent_loop.completion,
|
||||
answer_token=loop_answer_tokens,
|
||||
answer_unit_price=agent_answer_unit_price,
|
||||
latency=agent_loop.latency,
|
||||
tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
|
||||
total_price=loop_total_price,
|
||||
currency=llm_constant.model_currency,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
|
||||
db.session.add(message_agent_loop)
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = agent_message_unit_price
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = agent_answer_unit_price
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = llm_constant.model_currency
|
||||
db.session.flush()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_loop)
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_query_obj.dataset_id,
|
||||
@@ -345,16 +352,14 @@ class PubHandler:
|
||||
content = {
|
||||
'event': 'agent_thought',
|
||||
'data': {
|
||||
'id': message_agent_thought.id,
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'chain_id': message_agent_thought.message_chain_id,
|
||||
'agent_thought_id': message_agent_thought.id,
|
||||
'position': message_agent_thought.position,
|
||||
'thought': message_agent_thought.thought,
|
||||
'tool': message_agent_thought.tool,
|
||||
'tool_input': message_agent_thought.tool_input,
|
||||
'observation': message_agent_thought.observation,
|
||||
'answer': message_agent_thought.answer,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
}
|
||||
@@ -387,6 +392,15 @@ class PubHandler:
|
||||
def _is_stopped(self):
|
||||
return redis_client.get(self._stopped_cache_key) is not None
|
||||
|
||||
@classmethod
|
||||
def ping(cls, user: Union[Account | EndUser], task_id: str):
|
||||
content = {
|
||||
'event': 'ping'
|
||||
}
|
||||
|
||||
channel = cls.generate_channel_name(user, task_id)
|
||||
redis_client.publish(channel, json.dumps(content))
|
||||
|
||||
@classmethod
|
||||
def stop(cls, user: Union[Account | EndUser], task_id: str):
|
||||
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import requests
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
@@ -13,6 +14,9 @@ from core.data_loader.loader.pdf import PdfLoader
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class FileExtractor:
|
||||
@classmethod
|
||||
@@ -22,22 +26,41 @@ class FileExtractor:
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
if input_file.suffix == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif input_file.suffix == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif input_file.suffix in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif input_file.suffix in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif input_file.suffix == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif input_file.suffix == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
return cls.load_from_file(file_path, return_text, upload_file)
|
||||
|
||||
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
|
||||
return cls.load_from_file(file_path, return_text)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str, return_text: bool = False,
|
||||
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
if input_file.suffix == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif input_file.suffix == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif input_file.suffix in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif input_file.suffix in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif input_file.suffix == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif input_file.suffix == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
|
||||
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
|
||||
|
||||
@@ -39,7 +39,7 @@ class ExcelLoader(BaseLoader):
|
||||
row_dict = dict(zip(keys, list(map(str, row))))
|
||||
row_dict = {k: v for k, v in row_dict.items() if v}
|
||||
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
|
||||
document = Document(page_content=item)
|
||||
document = Document(page_content=item, metadata={'source': self._file_path})
|
||||
data.append(document)
|
||||
|
||||
return data
|
||||
|
||||
@@ -68,7 +68,7 @@ class DatesetDocumentStore:
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
) -> None:
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document == self._document_id
|
||||
DocumentSegment.document_id == self._document_id
|
||||
).scalar()
|
||||
|
||||
if max_position is None:
|
||||
@@ -105,9 +105,14 @@ class DatesetDocumentStore:
|
||||
tokens=tokens,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
if 'answer' in doc.metadata and doc.metadata['answer']:
|
||||
segment_document.answer = doc.metadata.pop('answer', '')
|
||||
|
||||
db.session.add(segment_document)
|
||||
else:
|
||||
segment_document.content = doc.page_content
|
||||
if 'answer' in doc.metadata and doc.metadata['answer']:
|
||||
segment_document.answer = doc.metadata.pop('answer', '')
|
||||
segment_document.index_node_hash = doc.metadata['doc_hash']
|
||||
segment_document.word_count = len(doc.page_content)
|
||||
segment_document.tokens = tokens
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
|
||||
text_embeddings.extend(embedding_results)
|
||||
return text_embeddings
|
||||
|
||||
@handle_openai_exceptions
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage, OutputParserException, BaseMessage
|
||||
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
@@ -12,8 +12,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO
|
||||
|
||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
|
||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
|
||||
|
||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
|
||||
GENERATOR_QA_PROMPT
|
||||
|
||||
# gpt-3.5-turbo works not well
|
||||
generate_base_model = 'text-davinci-003'
|
||||
@@ -23,11 +23,16 @@ class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_conversation_name(cls, tenant_id: str, query, answer):
|
||||
prompt = CONVERSATION_TITLE_PROMPT
|
||||
|
||||
if len(query) > 2000:
|
||||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
|
||||
prompt = prompt.format(query=query)
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=50
|
||||
max_tokens=50,
|
||||
timeout=600
|
||||
)
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
@@ -52,7 +57,17 @@ class LLMGenerator:
|
||||
if not message.answer:
|
||||
continue
|
||||
|
||||
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
|
||||
if len(message.query) > 2000:
|
||||
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
|
||||
else:
|
||||
query = message.query
|
||||
|
||||
if len(message.answer) > 2000:
|
||||
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
|
||||
else:
|
||||
answer = message.answer
|
||||
|
||||
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
|
||||
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
|
||||
context += message_qa_text
|
||||
|
||||
@@ -171,3 +186,27 @@ class LLMGenerator:
|
||||
}
|
||||
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
async def generate_qa_document(cls, llm: StreamableOpenAI, query):
|
||||
prompt = GENERATOR_QA_PROMPT
|
||||
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
|
||||
prompt = GENERATOR_QA_PROMPT
|
||||
|
||||
|
||||
if isinstance(llm, BaseChatModel):
|
||||
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
|
||||
|
||||
response = llm.generate([prompt])
|
||||
answer = response.generations[0][0].text
|
||||
return answer.strip()
|
||||
|
||||
@@ -17,7 +17,7 @@ class IndexBuilder:
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -205,6 +205,16 @@ class KeywordTableIndex(BaseIndex):
|
||||
document_segment.keywords = keywords
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
self._update_segment_keywords(node_id, keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
index: KeywordTableIndex
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
import concurrent
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, List, cast
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.docstore.dataset_docstore import DatesetDocumentStore
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.index.index import IndexBuilder
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
||||
from core.llm.token_calculator import TokenCalculator
|
||||
from extensions.ext_database import db
|
||||
@@ -70,7 +70,13 @@ class IndexingRunner:
|
||||
dataset_document=dataset_document,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
|
||||
# new_documents = []
|
||||
# for document in documents:
|
||||
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
|
||||
# document_qa_list = self.format_split_text(response)
|
||||
# for result in document_qa_list:
|
||||
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
|
||||
# new_documents.append(document)
|
||||
# build index
|
||||
self._build_index(
|
||||
dataset=dataset,
|
||||
@@ -91,6 +97,22 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
|
||||
matches = re.findall(regex, text, re.MULTILINE)
|
||||
|
||||
result = []
|
||||
for match in matches:
|
||||
q = match[0]
|
||||
a = match[1]
|
||||
if q and a:
|
||||
result.append({
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def run_in_splitting_status(self, dataset_document: DatasetDocument):
|
||||
"""Run the indexing process when the index_status is splitting."""
|
||||
try:
|
||||
@@ -205,7 +227,8 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
|
||||
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
@@ -225,7 +248,7 @@ class IndexingRunner:
|
||||
splitter = self._get_splitter(processing_rule)
|
||||
|
||||
# split to documents
|
||||
documents = self._split_to_documents(
|
||||
documents = self._split_to_documents_for_estimate(
|
||||
text_docs=text_docs,
|
||||
splitter=splitter,
|
||||
processing_rule=processing_rule
|
||||
@@ -237,7 +260,25 @@ class IndexingRunner:
|
||||
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
|
||||
self.filter_string(document.page_content))
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=2000
|
||||
)
|
||||
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
@@ -246,7 +287,7 @@ class IndexingRunner:
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
|
||||
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
@@ -285,7 +326,7 @@ class IndexingRunner:
|
||||
splitter = self._get_splitter(processing_rule)
|
||||
|
||||
# split to documents
|
||||
documents = self._split_to_documents(
|
||||
documents = self._split_to_documents_for_estimate(
|
||||
text_docs=documents,
|
||||
splitter=splitter,
|
||||
processing_rule=processing_rule
|
||||
@@ -296,7 +337,25 @@ class IndexingRunner:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=2000
|
||||
)
|
||||
response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
@@ -391,7 +450,9 @@ class IndexingRunner:
|
||||
documents = self._split_to_documents(
|
||||
text_docs=text_docs,
|
||||
splitter=splitter,
|
||||
processing_rule=processing_rule
|
||||
processing_rule=processing_rule,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_form=dataset_document.doc_form
|
||||
)
|
||||
|
||||
# save node to document segment
|
||||
@@ -428,7 +489,74 @@ class IndexingRunner:
|
||||
return documents
|
||||
|
||||
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
all_documents = []
|
||||
all_qa_documents = []
|
||||
for text_doc in text_docs:
|
||||
# document clean
|
||||
document_text = self._document_clean(text_doc.page_content, processing_rule)
|
||||
text_doc.page_content = document_text
|
||||
|
||||
# parse document to nodes
|
||||
documents = splitter.split_documents([text_doc])
|
||||
split_documents = []
|
||||
for document_node in documents:
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
# processing qa document
|
||||
if document_form == 'qa_model':
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
max_tokens=2000
|
||||
)
|
||||
for i in range(0, len(all_documents), 10):
|
||||
threads = []
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
|
||||
'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
return all_qa_documents
|
||||
return all_documents
|
||||
|
||||
def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
return
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
qa_document.metadata['doc_id'] = doc_id
|
||||
qa_document.metadata['doc_hash'] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
|
||||
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
@@ -445,7 +573,6 @@ class IndexingRunner:
|
||||
for document in documents:
|
||||
if document.page_content is None or not document.page_content.strip():
|
||||
continue
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document.page_content)
|
||||
|
||||
@@ -487,6 +614,23 @@ class IndexingRunner:
|
||||
|
||||
return text
|
||||
|
||||
def format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式
|
||||
matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果
|
||||
|
||||
result = [] # 存储最终的结果
|
||||
for match in matches:
|
||||
q = match[0]
|
||||
a = match[1]
|
||||
if q and a:
|
||||
# 如果Q和A都存在,就将其添加到结果中
|
||||
result.append({
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
|
||||
"""
|
||||
Build the index for the document.
|
||||
|
||||
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
|
||||
"""
|
||||
description = "Provider Token Not Init"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.description = args[0] if args else self.description
|
||||
|
||||
|
||||
class QuotaExceededError(Exception):
|
||||
"""
|
||||
|
||||
59
api/core/llm/fake.py
Normal file
59
api/core/llm/fake.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
from typing import List, Optional, Any, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import SimpleChatModel
|
||||
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel
|
||||
|
||||
|
||||
class FakeLLM(SimpleChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
response: str
|
||||
origin_llm: Optional[BaseLanguageModel] = None
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-chat-model"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""First try to lookup in queries, else return 'foo' or 'bar'."""
|
||||
return self.response
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {"response": self.response}
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
if self.streaming:
|
||||
for token in output_str:
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token)
|
||||
time.sleep(0.01)
|
||||
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
llm_output = {"token_usage": {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0,
|
||||
}}
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
|
||||
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
|
||||
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from models.provider import ProviderType
|
||||
from models.provider import ProviderType, ProviderName
|
||||
|
||||
|
||||
class LLMBuilder:
|
||||
@@ -32,43 +33,43 @@ class LLMBuilder:
|
||||
|
||||
@classmethod
|
||||
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
provider = cls.get_default_provider(tenant_id)
|
||||
provider = cls.get_default_provider(tenant_id, model_name)
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
llm_cls = None
|
||||
mode = cls.get_mode_by_model(model_name)
|
||||
if mode == 'chat':
|
||||
if provider == 'openai':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableChatOpenAI
|
||||
else:
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureChatOpenAI
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
llm_cls = StreamableChatAnthropic
|
||||
elif mode == 'completion':
|
||||
if provider == 'openai':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableOpenAI
|
||||
else:
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureOpenAI
|
||||
else:
|
||||
|
||||
if not llm_cls:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
|
||||
model_kwargs = {
|
||||
'model_name': model_name,
|
||||
'temperature': kwargs.get('temperature', 0),
|
||||
'max_tokens': kwargs.get('max_tokens', 256),
|
||||
'top_p': kwargs.get('top_p', 1),
|
||||
'frequency_penalty': kwargs.get('frequency_penalty', 0),
|
||||
'presence_penalty': kwargs.get('presence_penalty', 0),
|
||||
'callbacks': kwargs.get('callbacks', None),
|
||||
'streaming': kwargs.get('streaming', False),
|
||||
}
|
||||
|
||||
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
|
||||
model_kwargs.update(model_credentials)
|
||||
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
|
||||
|
||||
return llm_cls(
|
||||
model_name=model_name,
|
||||
temperature=kwargs.get('temperature', 0),
|
||||
max_tokens=kwargs.get('max_tokens', 256),
|
||||
**model_extras_kwargs,
|
||||
callbacks=kwargs.get('callbacks', None),
|
||||
streaming=kwargs.get('streaming', False),
|
||||
# request_timeout=None
|
||||
**model_credentials
|
||||
)
|
||||
return llm_cls(**model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
|
||||
@@ -118,14 +119,30 @@ class LLMBuilder:
|
||||
return provider_service.get_credentials(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_default_provider(cls, tenant_id: str) -> str:
|
||||
provider = BaseProvider.get_valid_provider(tenant_id)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
|
||||
provider_name = llm_constant.models[model_name]
|
||||
|
||||
if provider_name == 'openai':
|
||||
# get the default provider (openai / azure_openai) for the tenant
|
||||
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
|
||||
provider = azure_openai_provider
|
||||
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {provider_name} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider_name = 'openai'
|
||||
else:
|
||||
provider_name = provider.provider_name
|
||||
|
||||
return provider_name
|
||||
|
||||
@@ -1,23 +1,138 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import anthropic
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName, ProviderType
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
# todo
|
||||
return []
|
||||
return [
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
'name': 'claude-2',
|
||||
},
|
||||
]
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
|
||||
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
|
||||
"""
|
||||
return {
|
||||
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
return self.get_provider_api_key(model_id=model_id)
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.ANTHROPIC
|
||||
return ProviderName.ANTHROPIC
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
if obfuscated:
|
||||
if not config.get('anthropic_api_key'):
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
|
||||
return config
|
||||
|
||||
return config
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
Returns the encrypted token.
|
||||
"""
|
||||
return json.dumps({
|
||||
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
|
||||
})
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
"""
|
||||
Returns the decrypted token.
|
||||
"""
|
||||
config = json.loads(token)
|
||||
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
# check OpenAI / Azure OpenAI credential is valid
|
||||
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
|
||||
if quota_used >= quota_limit:
|
||||
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
|
||||
f"please configure OpenAI or Azure OpenAI provider first.")
|
||||
|
||||
try:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'anthropic_api_key' not in config:
|
||||
raise ValueError('anthropic_api_key must be provided.')
|
||||
|
||||
chat_llm = ChatAnthropic(
|
||||
model='claude-instant-1',
|
||||
anthropic_api_key=config['anthropic_api_key'],
|
||||
max_tokens_to_sample=10,
|
||||
temperature=0,
|
||||
default_request_timeout=60
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="ping"
|
||||
)
|
||||
]
|
||||
|
||||
chat_llm(messages)
|
||||
except anthropic.APIConnectionError as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
|
||||
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
|
||||
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
|
||||
except Exception as ex:
|
||||
logging.exception('Anthropic config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
@@ -9,32 +10,42 @@ from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
|
||||
|
||||
|
||||
class AzureProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id) if not credentials else credentials
|
||||
url = "{}/openai/deployments?api-version={}".format(
|
||||
str(credentials.get('openai_api_base')),
|
||||
str(credentials.get('openai_api_version'))
|
||||
)
|
||||
return []
|
||||
|
||||
headers = {
|
||||
"api-key": str(credentials.get('openai_api_key')),
|
||||
"content-type": "application/json; charset=utf-8"
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return [{
|
||||
'id': deployment['id'],
|
||||
'name': '{} ({})'.format(deployment['id'], deployment['model'])
|
||||
} for deployment in result['data'] if deployment['status'] == 'succeeded']
|
||||
else:
|
||||
if response.status_code == 401:
|
||||
raise AzureAuthenticationError()
|
||||
def check_embedding_model(self, credentials: Optional[dict] = None):
|
||||
credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
|
||||
try:
|
||||
result = openai.Embedding.create(input=['test'],
|
||||
engine='text-embedding-ada-002',
|
||||
timeout=60,
|
||||
api_key=str(credentials.get('openai_api_key')),
|
||||
api_base=str(credentials.get('openai_api_base')),
|
||||
api_type='azure',
|
||||
api_version=str(credentials.get('openai_api_version')))["data"][0][
|
||||
"embedding"]
|
||||
except openai.error.AuthenticationError as e:
|
||||
raise AzureAuthenticationError(str(e))
|
||||
except openai.error.APIConnectionError as e:
|
||||
raise AzureRequestFailedError(
|
||||
'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
|
||||
except openai.error.InvalidRequestError as e:
|
||||
if e.http_status == 404:
|
||||
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
|
||||
"deployment name is exists in Azure AI")
|
||||
else:
|
||||
raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
|
||||
raise AzureRequestFailedError(
|
||||
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
|
||||
except openai.error.OpenAIError as e:
|
||||
raise AzureRequestFailedError(
|
||||
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
|
||||
|
||||
if not isinstance(result, list):
|
||||
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
@@ -42,9 +53,10 @@ class AzureProvider(BaseProvider):
|
||||
"""
|
||||
config = self.get_provider_api_key(model_id=model_id)
|
||||
config['openai_api_type'] = 'azure'
|
||||
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
|
||||
if model_id == 'text-embedding-ada-002':
|
||||
config['deployment'] = model_id.replace('.', '') if model_id else None
|
||||
config['chunk_size'] = 1
|
||||
config['chunk_size'] = 16
|
||||
else:
|
||||
config['deployment_name'] = model_id.replace('.', '') if model_id else None
|
||||
return config
|
||||
@@ -52,16 +64,16 @@ class AzureProvider(BaseProvider):
|
||||
def get_provider_name(self):
|
||||
return ProviderName.AZURE_OPENAI
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_version': AZURE_OPENAI_API_VERSION,
|
||||
'openai_api_base': '',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
@@ -70,7 +82,7 @@ class AzureProvider(BaseProvider):
|
||||
if not config.get('openai_api_key'):
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_version': AZURE_OPENAI_API_VERSION,
|
||||
'openai_api_base': '',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
@@ -81,7 +93,6 @@ class AzureProvider(BaseProvider):
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
# TODO: change to dict when implemented
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
@@ -93,34 +104,13 @@ class AzureProvider(BaseProvider):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'openai_api_version' not in config:
|
||||
config['openai_api_version'] = '2023-03-15-preview'
|
||||
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
|
||||
|
||||
models = self.get_models(credentials=config)
|
||||
|
||||
if not models:
|
||||
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
|
||||
"'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
|
||||
"and 'gpt-4', 'gpt-35-turbo-16k' (optional).")
|
||||
|
||||
fixed_model_ids = [
|
||||
'text-davinci-003',
|
||||
'gpt-35-turbo',
|
||||
'text-embedding-ada-002'
|
||||
]
|
||||
|
||||
current_model_ids = [model['id'] for model in models]
|
||||
|
||||
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
|
||||
fixed_model_id not in current_model_ids]
|
||||
|
||||
if missing_model_ids:
|
||||
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
|
||||
self.check_embedding_model(credentials=config)
|
||||
except ValidateFailedError as e:
|
||||
raise e
|
||||
except AzureAuthenticationError:
|
||||
raise ValidateFailedError('Validation failed, please check your API Key.')
|
||||
except (requests.ConnectionError, requests.RequestException):
|
||||
raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
|
||||
except AzureRequestFailedError as ex:
|
||||
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
||||
except Exception as ex:
|
||||
@@ -133,7 +123,7 @@ class AzureProvider(BaseProvider):
|
||||
"""
|
||||
return json.dumps({
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_version': AZURE_OPENAI_API_VERSION,
|
||||
'openai_api_base': config['openai_api_base'],
|
||||
'openai_api_key': self.encrypt_token(config['openai_api_key'])
|
||||
})
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the decrypted API key for the given tenant_id and provider_name.
|
||||
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
|
||||
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
|
||||
"""
|
||||
provider = self.get_provider(prefer_custom)
|
||||
provider = self.get_provider(only_custom)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
|
||||
else:
|
||||
return self.get_decrypted_token(provider.encrypted_config)
|
||||
|
||||
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
|
||||
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
|
||||
|
||||
@classmethod
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
|
||||
Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
If both CUSTOM and System providers exist.
|
||||
"""
|
||||
query = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id
|
||||
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
|
||||
if provider_name:
|
||||
query = query.filter(Provider.provider_name == provider_name)
|
||||
|
||||
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
|
||||
if only_custom:
|
||||
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
|
||||
|
||||
custom_provider = None
|
||||
system_provider = None
|
||||
providers = query.order_by(Provider.provider_type.asc()).all()
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
return provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
return provider
|
||||
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
elif system_provider:
|
||||
return system_provider
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_hosted_credentials(self) -> str:
|
||||
if self.get_provider_name() != ProviderName.OPENAI:
|
||||
raise ProviderTokenNotInitError()
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = ''
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ class LLMProviderService:
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
return self.provider.get_credentials(model_id)
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated)
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
|
||||
|
||||
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
return self.provider.get_provider(prefer_custom)
|
||||
def get_provider_db_record(self) -> Optional[Provider]:
|
||||
return self.provider.get_provider()
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Optional, Union
|
||||
import openai
|
||||
from openai.error import AuthenticationError, OpenAIError
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.moderation import Moderation
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
|
||||
except Exception as ex:
|
||||
logging.exception('OpenAI config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
|
||||
from langchain.chat_models.openai import _convert_dict_to_message
|
||||
from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
@@ -46,30 +52,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: The messages to count the tokens of.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the messages.
|
||||
"""
|
||||
tokens_per_message = 5
|
||||
tokens_per_request = 3
|
||||
|
||||
message_tokens = tokens_per_request
|
||||
message_strs = ''
|
||||
for message in messages:
|
||||
message_strs += message.content
|
||||
message_tokens += tokens_per_message
|
||||
|
||||
# calc once
|
||||
message_tokens += self.get_num_tokens(message_strs)
|
||||
|
||||
return message_tokens
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -79,12 +62,58 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(messages, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
model_kwargs = {
|
||||
'top_p': params.get('top_p', 1),
|
||||
'frequency_penalty': params.get('frequency_penalty', 0),
|
||||
'presence_penalty': params.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
del params['top_p']
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
params['model_kwargs'] = model_kwargs
|
||||
|
||||
return params
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
if self.streaming:
|
||||
inner_completion = ""
|
||||
role = "assistant"
|
||||
params["stream"] = True
|
||||
function_call: Optional[dict] = None
|
||||
for stream_resp in self.completion_with_retry(
|
||||
messages=message_dicts, **params
|
||||
):
|
||||
if len(stream_resp["choices"]) > 0:
|
||||
role = stream_resp["choices"][0]["delta"].get("role", role)
|
||||
token = stream_resp["choices"][0]["delta"].get("content") or ""
|
||||
inner_completion += token
|
||||
_function_call = stream_resp["choices"][0]["delta"].get("function_call")
|
||||
if _function_call:
|
||||
if function_call is None:
|
||||
function_call = _function_call
|
||||
else:
|
||||
function_call["arguments"] += _function_call["arguments"]
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token)
|
||||
message = _convert_dict_to_message(
|
||||
{
|
||||
"content": inner_completion,
|
||||
"role": role,
|
||||
"function_call": function_call,
|
||||
}
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
response = self.completion_with_retry(messages=message_dicts, **params)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import AzureOpenAI
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List, Dict, Mapping, Any
|
||||
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableAzureOpenAI(AzureOpenAI):
|
||||
openai_api_type: str = "azure"
|
||||
openai_api_version: str = ""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -50,7 +54,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -60,12 +64,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
62
api/core/llm/streamable_chat_anthropic.py
Normal file
62
api/core/llm/streamable_chat_anthropic.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
|
||||
from httpx import Timeout
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
|
||||
|
||||
|
||||
class StreamableChatAnthropic(ChatAnthropic):
|
||||
"""
|
||||
Wrapper around Anthropic's large language model.
|
||||
"""
|
||||
|
||||
default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
|
||||
|
||||
@root_validator()
|
||||
def prepare_params(cls, values: Dict) -> Dict:
|
||||
values['model_name'] = values.get('model')
|
||||
values['max_tokens'] = values.get('max_tokens_to_sample')
|
||||
return values
|
||||
|
||||
@handle_anthropic_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
params['model'] = params.get('model_name')
|
||||
del params['model_name']
|
||||
|
||||
params['max_tokens_to_sample'] = params.get('max_tokens')
|
||||
del params['max_tokens']
|
||||
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
return params
|
||||
|
||||
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"{self.HUMAN_PROMPT} {message.content}"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{self.AI_PROMPT} {message.content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = f"<admin>{message.content}</admin>"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
||||
@@ -3,14 +3,18 @@ import os
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import BaseMessage, LLMResult
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, List, Dict, Any, Union, Tuple
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableChatOpenAI(ChatOpenAI):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -48,30 +52,7 @@ class StreamableChatOpenAI(ChatOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: The messages to count the tokens of.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the messages.
|
||||
"""
|
||||
tokens_per_message = 5
|
||||
tokens_per_request = 3
|
||||
|
||||
message_tokens = tokens_per_request
|
||||
message_strs = ''
|
||||
for message in messages:
|
||||
message_strs += message.content
|
||||
message_tokens += tokens_per_message
|
||||
|
||||
# calc once
|
||||
message_tokens += self.get_num_tokens(message_strs)
|
||||
|
||||
return message_tokens
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -81,12 +62,18 @@ class StreamableChatOpenAI(ChatOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(messages, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
model_kwargs = {
|
||||
'top_p': params.get('top_p', 1),
|
||||
'frequency_penalty': params.get('frequency_penalty', 0),
|
||||
'presence_penalty': params.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
del params['top_p']
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
params['model_kwargs'] = model_kwargs
|
||||
|
||||
return params
|
||||
|
||||
@@ -2,14 +2,18 @@ import os
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List, Dict, Any, Mapping
|
||||
from typing import Optional, List, Dict, Any, Mapping, Union, Tuple
|
||||
from langchain import OpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableOpenAI(OpenAI):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -49,7 +53,7 @@ class StreamableOpenAI(OpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -59,12 +63,6 @@ class StreamableOpenAI(OpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import openai
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from models.provider import ProviderName
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions
|
||||
from core.llm.provider.base import BaseProvider
|
||||
|
||||
|
||||
@@ -13,7 +14,7 @@ class Whisper:
|
||||
self.client = openai.Audio
|
||||
self.credentials = provider.get_credentials()
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def transcribe(self, file):
|
||||
return self.client.transcribe(
|
||||
model='whisper-1',
|
||||
|
||||
27
api/core/llm/wrappers/anthropic_wrapper.py
Normal file
27
api/core/llm/wrappers/anthropic_wrapper.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
import anthropic
|
||||
|
||||
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_anthropic_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except anthropic.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to Anthropic API.")
|
||||
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
|
||||
except anthropic.RateLimitError:
|
||||
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(f"Anthropic: {e.message}")
|
||||
except anthropic.BadRequestError as e:
|
||||
raise LLMBadRequestError(f"Anthropic: {e.message}")
|
||||
except anthropic.APIStatusError as e:
|
||||
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
|
||||
|
||||
return wrapper
|
||||
@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_llm_exceptions(func):
|
||||
def handle_openai_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def handle_llm_exceptions_async(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except openai.error.InvalidRequestError as e:
|
||||
logging.exception("Invalid request to OpenAI API.")
|
||||
raise LLMBadRequestError(str(e))
|
||||
except openai.error.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to OpenAI API.")
|
||||
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
|
||||
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
|
||||
logging.exception("OpenAI service unavailable.")
|
||||
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
|
||||
except openai.error.RateLimitError as e:
|
||||
raise LLMRateLimitError(str(e))
|
||||
except openai.error.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(str(e))
|
||||
except openai.error.OpenAIError as e:
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
|
||||
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
|
||||
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
conversation: Conversation
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
|
||||
ai_prefix: str = "Assistant"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "chat_history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
return chat_messages
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
|
||||
return chat_messages
|
||||
|
||||
|
||||
301
api/core/orchestrator_rule_parser.py
Normal file
301
api/core/orchestrator_rule_parser.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
|
||||
from core.tool.web_reader_tool import WebReaderTool
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.model import AppModelConfig
|
||||
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
"""Parse the orchestrator rule to entities."""
|
||||
|
||||
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
|
||||
self.tenant_id = tenant_id
|
||||
self.app_model_config = app_model_config
|
||||
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
|
||||
self.dataset_retrieve_model_name = "gpt-3.5-turbo"
|
||||
|
||||
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
|
||||
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
|
||||
-> Optional[AgentExecutor]:
|
||||
if not self.app_model_config.agent_mode_dict:
|
||||
return None
|
||||
|
||||
agent_mode_config = self.app_model_config.agent_mode_dict
|
||||
model_dict = self.app_model_config.model_dict
|
||||
|
||||
chain = None
|
||||
if agent_mode_config and agent_mode_config.get('enabled'):
|
||||
tool_configs = agent_mode_config.get('tools', [])
|
||||
agent_model_name = model_dict.get('name', 'gpt-4')
|
||||
|
||||
# add agent callback to record agent thoughts
|
||||
agent_callback = AgentLoopGatherCallbackHandler(
|
||||
model_name=agent_model_name,
|
||||
conversation_message_task=conversation_message_task
|
||||
)
|
||||
|
||||
chain_callback.agent_callback = agent_callback
|
||||
|
||||
agent_llm = LLMBuilder.to_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
model_name=agent_model_name,
|
||||
temperature=0,
|
||||
max_tokens=1500,
|
||||
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
|
||||
|
||||
# only OpenAI chat model (include Azure) support function call, use ReACT instead
|
||||
if not isinstance(agent_llm, ChatOpenAI) \
|
||||
and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
|
||||
planning_strategy = PlanningStrategy.REACT
|
||||
|
||||
summary_llm = LLMBuilder.to_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
model_name=self.agent_summary_model_name,
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
tools = self.to_tools(
|
||||
tool_configs=tool_configs,
|
||||
conversation_message_task=conversation_message_task,
|
||||
model_name=self.agent_summary_model_name,
|
||||
rest_tokens=rest_tokens,
|
||||
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
if len(tools) == 0:
|
||||
return None
|
||||
|
||||
dataset_llm = LLMBuilder.to_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
model_name=self.dataset_retrieve_model_name,
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
agent_configuration = AgentConfiguration(
|
||||
strategy=planning_strategy,
|
||||
llm=agent_llm,
|
||||
tools=tools,
|
||||
summary_llm=summary_llm,
|
||||
dataset_llm=dataset_llm,
|
||||
memory=memory,
|
||||
callbacks=[chain_callback, agent_callback],
|
||||
max_iterations=10,
|
||||
max_execution_time=400.0,
|
||||
early_stopping_method="generate"
|
||||
)
|
||||
|
||||
return AgentExecutor(agent_configuration)
|
||||
|
||||
return chain
|
||||
|
||||
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
|
||||
-> Optional[SensitiveWordAvoidanceChain]:
|
||||
"""
|
||||
Convert app sensitive word avoidance config to chain
|
||||
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if not self.app_model_config.sensitive_word_avoidance_dict:
|
||||
return None
|
||||
|
||||
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
|
||||
sensitive_words = sensitive_word_avoidance_config.get("words", "")
|
||||
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
|
||||
return SensitiveWordAvoidanceChain(
|
||||
sensitive_words=sensitive_words.split(","),
|
||||
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
|
||||
output_key="sensitive_word_avoidance_output",
|
||||
callbacks=callbacks,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
||||
model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
|
||||
"""
|
||||
Convert app agent tool configs to tools
|
||||
|
||||
:param rest_tokens:
|
||||
:param tool_configs: app agent tool configs
|
||||
:param model_name:
|
||||
:param conversation_message_task:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
tools = []
|
||||
for tool_config in tool_configs:
|
||||
tool_type = list(tool_config.keys())[0]
|
||||
tool_val = list(tool_config.values())[0]
|
||||
if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
|
||||
continue
|
||||
|
||||
tool = None
|
||||
if tool_type == "dataset":
|
||||
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
|
||||
elif tool_type == "web_reader":
|
||||
tool = self.to_web_reader_tool(model_name)
|
||||
elif tool_type == "google_search":
|
||||
tool = self.to_google_search_tool()
|
||||
elif tool_type == "wikipedia":
|
||||
tool = self.to_wikipedia_tool()
|
||||
elif tool_type == "current_datetime":
|
||||
tool = self.to_current_datetime_tool()
|
||||
|
||||
if tool:
|
||||
tool.callbacks.extend(callbacks)
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int) \
|
||||
-> Optional[BaseTool]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param rest_tokens:
|
||||
:param tool_config:
|
||||
:param conversation_message_task:
|
||||
:return:
|
||||
"""
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == tool_config.get("id")
|
||||
).first()
|
||||
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
return None
|
||||
|
||||
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
k=k,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for reading web pages
|
||||
|
||||
:return:
|
||||
"""
|
||||
summary_llm = LLMBuilder.to_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
model_name=model_name,
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
tool = WebReaderTool(
|
||||
llm=summary_llm,
|
||||
max_chunk_length=4000,
|
||||
continue_reading=True,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_google_search_tool(self) -> Optional[BaseTool]:
|
||||
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
|
||||
func_kwargs = tool_provider.credentials_to_func_kwargs()
|
||||
if not func_kwargs:
|
||||
return None
|
||||
|
||||
tool = Tool(
|
||||
name="google_search",
|
||||
description="A tool for performing a Google search and extracting snippets and webpages "
|
||||
"when you need to search for something you don't know or when your information "
|
||||
"is not up to date. "
|
||||
"Input should be a search query.",
|
||||
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
|
||||
args_schema=OptimizedSerpAPIInput,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_current_datetime_tool(self) -> Optional[BaseTool]:
|
||||
tool = Tool(
|
||||
name="current_datetime",
|
||||
description="A tool when you want to get the current date, time, week, month or year, "
|
||||
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
|
||||
func=helper.get_current_datetime,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_wikipedia_tool(self) -> Optional[BaseTool]:
|
||||
class WikipediaInput(BaseModel):
|
||||
query: str = Field(..., description="search query.")
|
||||
|
||||
return WikipediaQueryRun(
|
||||
name="wikipedia",
|
||||
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
|
||||
args_schema=WikipediaInput,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
|
||||
DEFAULT_K = 2
|
||||
CONTEXT_TOKENS_PERCENT = 0.3
|
||||
processing_rule = dataset.latest_process_rule
|
||||
if not processing_rule:
|
||||
return DEFAULT_K
|
||||
|
||||
if processing_rule.mode == "custom":
|
||||
rules = processing_rule.rules_dict
|
||||
if not rules:
|
||||
return DEFAULT_K
|
||||
|
||||
segmentation = rules["segmentation"]
|
||||
segment_max_tokens = segmentation["max_tokens"]
|
||||
else:
|
||||
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
|
||||
|
||||
# when rest_tokens is less than default context tokens
|
||||
if rest_tokens < segment_max_tokens * DEFAULT_K:
|
||||
return rest_tokens // segment_max_tokens
|
||||
|
||||
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
|
||||
|
||||
# when context_limit_tokens is less than default context tokens, use default_k
|
||||
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
|
||||
return DEFAULT_K
|
||||
|
||||
# Expand the k value when there's still some room left in the 30% rest tokens space
|
||||
return context_limit_tokens // segment_max_tokens
|
||||
@@ -43,6 +43,16 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"[\"question1\",\"question2\",\"question3\"]\n"
|
||||
)
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n"
|
||||
'Step 1: Understand and summarize the main content of this text.\n'
|
||||
'Step 2: What key information or concepts are mentioned in this text?\n'
|
||||
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
|
||||
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
|
||||
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
|
||||
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
)
|
||||
|
||||
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
|
||||
the model prompt that best suits the input.
|
||||
You will be provided with the prompt, variables, and an opening statement.
|
||||
|
||||
@@ -7,7 +7,7 @@ from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class DatasetTool(BaseTool):
|
||||
@@ -27,10 +27,11 @@ class DatasetTool(BaseTool):
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=self.dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
@@ -54,13 +55,27 @@ class DatasetTool(BaseTool):
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in documents]
|
||||
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
).all()
|
||||
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
if segments:
|
||||
for segment in segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(segment.answer)
|
||||
else:
|
||||
document_context_list.append(segment.content)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=self.dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
121
api/core/tool/dataset_retriever_tool.py
Normal file
121
api/core/tool/dataset_retriever_tool.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
class DatasetRetrieverTool(BaseTool):
|
||||
"""Tool for querying a Dataset."""
|
||||
name: str = "dataset"
|
||||
args_schema: Type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
k: int = 3
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||
|
||||
description = description.replace('\n', '').replace('\r', '')
|
||||
description += '\nID of dataset MUST be ' + dataset.id
|
||||
return cls(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _run(self, dataset_id: str, query: str) -> str:
|
||||
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
|
||||
match = re.search(pattern, dataset_id, re.IGNORECASE)
|
||||
if match:
|
||||
dataset_id = match.group()
|
||||
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
kw_table_index = KeywordTableIndex(
|
||||
dataset=dataset,
|
||||
config=KeywordTableConfig(
|
||||
max_keywords_per_chunk=5
|
||||
)
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(OpenAIEmbeddings(
|
||||
**model_credentials
|
||||
))
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
if self.k > 0:
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity',
|
||||
search_kwargs={
|
||||
'k': self.k
|
||||
}
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in documents]
|
||||
segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
for segment in segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
|
||||
else:
|
||||
document_context_list.append(segment.content)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
63
api/core/tool/provider/base.py
Normal file
63
api/core/tool/provider/base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
from models.account import Tenant
|
||||
from models.tool import ToolProvider, ToolProviderName
|
||||
|
||||
|
||||
class BaseToolProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_name(self) -> ToolProviderName:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def credentials_to_func_kwargs(self) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def credentials_validate(self, credentials: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and tool_name.
|
||||
"""
|
||||
query = db.session.query(ToolProvider).filter(
|
||||
ToolProvider.tenant_id == self.tenant_id,
|
||||
ToolProvider.tool_name == self.get_provider_name().value
|
||||
)
|
||||
|
||||
if must_enabled:
|
||||
query = query.filter(ToolProvider.is_enabled == True)
|
||||
|
||||
return query.first()
|
||||
|
||||
def encrypt_token(self, token) -> str:
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
|
||||
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
|
||||
|
||||
if obfuscated:
|
||||
return self._obfuscated_token(token)
|
||||
|
||||
return token
|
||||
|
||||
def _obfuscated_token(self, token: str) -> str:
|
||||
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
||||
2
api/core/tool/provider/errors.py
Normal file
2
api/core/tool/provider/errors.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class ToolValidateFailedError(Exception):
|
||||
description = "Tool Provider Validate failed"
|
||||
77
api/core/tool/provider/serpapi_provider.py
Normal file
77
api/core/tool/provider/serpapi_provider.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.tool.provider.base import BaseToolProvider
|
||||
from core.tool.provider.errors import ToolValidateFailedError
|
||||
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
|
||||
from models.tool import ToolProviderName
|
||||
|
||||
|
||||
class SerpAPIToolProvider(BaseToolProvider):
|
||||
def get_provider_name(self) -> ToolProviderName:
|
||||
"""
|
||||
Returns the name of the provider.
|
||||
|
||||
:return:
|
||||
"""
|
||||
return ToolProviderName.SERPAPI
|
||||
|
||||
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
Returns the credentials for SerpAPI as a dictionary.
|
||||
|
||||
:param obfuscated: obfuscate credentials if True
|
||||
:return:
|
||||
"""
|
||||
tool_provider = self.get_provider(must_enabled=True)
|
||||
if not tool_provider:
|
||||
return None
|
||||
|
||||
credentials = tool_provider.credentials
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
if credentials.get('api_key'):
|
||||
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
|
||||
|
||||
return credentials
|
||||
|
||||
def credentials_to_func_kwargs(self) -> Optional[dict]:
|
||||
"""
|
||||
Returns the credentials function kwargs as a dictionary.
|
||||
|
||||
:return:
|
||||
"""
|
||||
credentials = self.get_credentials()
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
return {
|
||||
'serpapi_api_key': credentials.get('api_key')
|
||||
}
|
||||
|
||||
def credentials_validate(self, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
if 'api_key' not in credentials or not credentials.get('api_key'):
|
||||
raise ToolValidateFailedError("SerpAPI api_key is required.")
|
||||
|
||||
api_key = credentials.get('api_key')
|
||||
|
||||
try:
|
||||
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
|
||||
except Exception as e:
|
||||
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
|
||||
|
||||
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
|
||||
"""
|
||||
Encrypts the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
|
||||
return credentials
|
||||
43
api/core/tool/provider/tool_provider_service.py
Normal file
43
api/core/tool/provider/tool_provider_service.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.tool.provider.base import BaseToolProvider
|
||||
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||
|
||||
|
||||
class ToolProviderService:
|
||||
|
||||
def __init__(self, tenant_id: str, provider_name: str):
|
||||
self.provider = self._init_provider(tenant_id, provider_name)
|
||||
|
||||
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
|
||||
if provider_name == 'serpapi':
|
||||
return SerpAPIToolProvider(tenant_id)
|
||||
else:
|
||||
raise Exception('tool provider {} not found'.format(provider_name))
|
||||
|
||||
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
Returns the credentials for Tool as a dictionary.
|
||||
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.provider.get_credentials(obfuscated)
|
||||
|
||||
def credentials_validate(self, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:raises: ValidateFailedError
|
||||
"""
|
||||
return self.provider.credentials_validate(credentials)
|
||||
|
||||
def encrypt_credentials(self, credentials: dict):
|
||||
"""
|
||||
Encrypts the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return self.provider.encrypt_credentials(credentials)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user