mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 10:52:57 +08:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2a5f8ba1a | ||
|
|
8e11200306 | ||
|
|
7599f79a17 | ||
|
|
510389909c | ||
|
|
2c6e00174b | ||
|
|
24f3456990 | ||
|
|
20514ff288 | ||
|
|
381d255290 | ||
|
|
7f320f9146 | ||
|
|
cd51d3323b | ||
|
|
004b3caa43 | ||
|
|
dbe10799e3 | ||
|
|
054ba88434 | ||
|
|
da82a11b26 | ||
|
|
fec607db81 | ||
|
|
397a92f2ee | ||
|
|
b91e226063 | ||
|
|
da5782df92 | ||
|
|
9af0da4450 | ||
|
|
d49ac1e4ac |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -109,6 +109,7 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
|
||||
@@ -8,13 +8,19 @@ EDITION=SELF_HOSTED
|
||||
SECRET_KEY=
|
||||
|
||||
# Console API base URL
|
||||
CONSOLE_URL=http://127.0.0.1:5001
|
||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Console frontend web base URL
|
||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Service API base URL
|
||||
API_URL=http://127.0.0.1:5001
|
||||
SERVICE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP base URL
|
||||
APP_URL=http://127.0.0.1:3000
|
||||
# Web APP API base URL
|
||||
APP_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP frontend web base URL
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
@@ -79,6 +85,11 @@ WEAVIATE_BATCH_SIZE=100
|
||||
QDRANT_URL=path:storage/qdrant
|
||||
QDRANT_API_KEY=your-qdrant-api-key
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
RESEND_API_KEY=
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
|
||||
|
||||
@@ -5,9 +5,11 @@ LABEL maintainer="takatost@gmail.com"
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
ENV CONSOLE_URL http://127.0.0.1:5001
|
||||
ENV API_URL http://127.0.0.1:5001
|
||||
ENV APP_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_API_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_WEB_URL http://127.0.0.1:3000
|
||||
ENV SERVICE_API_URL http://127.0.0.1:5001
|
||||
ENV APP_API_URL http://127.0.0.1:5001
|
||||
ENV APP_WEB_URL http://127.0.0.1:3000
|
||||
|
||||
EXPOSE 5001
|
||||
|
||||
|
||||
16
api/app.py
16
api/app.py
@@ -2,6 +2,8 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
@@ -15,7 +17,7 @@ import flask_login
|
||||
from flask_cors import CORS
|
||||
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage
|
||||
ext_database, ext_storage, ext_mail
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
@@ -27,7 +29,7 @@ from events import event_handlers
|
||||
import core
|
||||
from config import Config, CloudEditionConfig
|
||||
from commands import register_commands
|
||||
from models.account import TenantAccountJoin
|
||||
from models.account import TenantAccountJoin, AccountStatus
|
||||
from models.model import Account, EndUser, App
|
||||
|
||||
import warnings
|
||||
@@ -83,6 +85,7 @@ def initialize_extensions(app):
|
||||
ext_celery.init_app(app)
|
||||
ext_session.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
|
||||
|
||||
@@ -100,6 +103,9 @@ def load_user(user_id):
|
||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if account:
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise Forbidden('Account is banned or closed.')
|
||||
|
||||
workspace_id = session.get('workspace_id')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
@@ -149,13 +155,17 @@ def register_blueprints(app):
|
||||
from controllers.web import bp as web_bp
|
||||
from controllers.console import bp as console_app_bp
|
||||
|
||||
CORS(service_api_bp,
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(web_bp,
|
||||
resources={
|
||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
||||
supports_credentials=True,
|
||||
allow_headers=['Content-Type', 'Authorization'],
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||
expose_headers=['X-Version', 'X-Env']
|
||||
)
|
||||
|
||||
@@ -18,7 +18,8 @@ from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider
|
||||
from models.provider import Provider, ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
|
||||
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
|
||||
def sync_anthropic_hosted_providers():
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for tenant in tenants:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(generate_invitation_codes)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(recreate_all_dataset_indexes)
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
|
||||
@@ -28,9 +28,11 @@ DEFAULTS = {
|
||||
'SESSION_REDIS_USE_SSL': 'False',
|
||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||
'CONSOLE_URL': 'https://cloud.dify.ai',
|
||||
'API_URL': 'https://api.dify.ai',
|
||||
'APP_URL': 'https://udify.app',
|
||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||
'CONSOLE_API_URL': 'https://cloud.dify.ai',
|
||||
'SERVICE_API_URL': 'https://api.dify.ai',
|
||||
'APP_WEB_URL': 'https://udify.app',
|
||||
'APP_API_URL': 'https://udify.app',
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
@@ -48,7 +50,10 @@ DEFAULTS = {
|
||||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'DEFAULT_LLM_PROVIDER': 'openai'
|
||||
'DEFAULT_LLM_PROVIDER': 'openai',
|
||||
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
|
||||
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
|
||||
'TENANT_DOCUMENT_COUNT': 100
|
||||
}
|
||||
|
||||
|
||||
@@ -76,10 +81,15 @@ class Config:
|
||||
|
||||
def __init__(self):
|
||||
# app settings
|
||||
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
|
||||
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
|
||||
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
|
||||
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
|
||||
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.7"
|
||||
self.CURRENT_VERSION = "0.3.9"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -147,10 +157,15 @@ class Config:
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'WEB_API_CORS_ALLOW_ORIGINS', '*')
|
||||
|
||||
# mail settings
|
||||
self.MAIL_TYPE = get_env('MAIL_TYPE')
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
|
||||
# sentry settings
|
||||
self.SENTRY_DSN = get_env('SENTRY_DSN')
|
||||
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
|
||||
@@ -179,6 +194,10 @@ class Config:
|
||||
|
||||
# hosted provider credentials
|
||||
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
|
||||
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
|
||||
|
||||
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
|
||||
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
@@ -195,6 +214,8 @@ class Config:
|
||||
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
|
||||
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
|
||||
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from . import setup, version, apikey, admin
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
|
||||
@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
|
||||
error_code = 'provider_quota_exceeded'
|
||||
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
|
||||
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
code = 400
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
|
||||
args['audiences'],
|
||||
args['hoping_to_solve']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
75
api/controllers/console/auth/activate.py
Normal file
75
api/controllers/console/auth/activate.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, supported_language, timezone
|
||||
from libs.password import valid_password, hash_password
|
||||
from models.account import AccountStatus, Tenant
|
||||
from services.account_service import RegisterService
|
||||
|
||||
|
||||
class ActivateCheckApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == args['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
return {'is_valid': account is not None, 'workspace_name': tenant.name}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if account is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
account.name = args['name']
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(args['password'], salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
account.interface_language = args['interface_language']
|
||||
account.timezone = args['timezone']
|
||||
account.interface_theme = 'light'
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, '/activate/check')
|
||||
api.add_resource(ActivateApi, '/activate')
|
||||
@@ -20,7 +20,7 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'NOTION_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
@@ -42,7 +42,7 @@ class OAuthDataSource(Resource):
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
@@ -66,12 +66,12 @@ class OAuthDataSourceCallback(Resource):
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source={error}')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=access_denied')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
|
||||
@@ -20,13 +20,13 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'GITHUB_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/github')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
|
||||
|
||||
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'GOOGLE_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/google')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'github': github_oauth,
|
||||
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
|
||||
flask_login.login_user(account, remember=True)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
|
||||
document_data=args,
|
||||
account=current_user
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
|
||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -18,3 +18,9 @@ class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = 'account_not_link_tenant'
|
||||
description = "Account not link tenant."
|
||||
code = 403
|
||||
|
||||
|
||||
class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = 'already_activate'
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -6,22 +6,23 @@ from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \
|
||||
RepeatPasswordNotMatchError
|
||||
RepeatPasswordNotMatchError, CurrentPasswordIncorrectError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import TimestampField, supported_language, timezone
|
||||
from extensions.ext_database import db
|
||||
from models.account import InvitationCode, AccountIntegrate
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'is_password_set': fields.Boolean,
|
||||
'interface_language': fields.String,
|
||||
'interface_theme': fields.String,
|
||||
'timezone': fields.String,
|
||||
@@ -194,8 +195,11 @@ class AccountPasswordApi(Resource):
|
||||
if args['new_password'] != args['repeat_new_password']:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
try:
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ class RepeatPasswordNotMatchError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class CurrentPasswordIncorrectError(BaseHTTPException):
|
||||
error_code = 'current_password_incorrect'
|
||||
description = "Current password is incorrect."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderRequestFailedError(BaseHTTPException):
|
||||
error_code = 'provider_request_failed'
|
||||
description = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
@@ -60,7 +60,8 @@ class MemberInviteEmailApi(Resource):
|
||||
inviter = current_user
|
||||
|
||||
try:
|
||||
RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == args['email']).first()
|
||||
@@ -78,7 +79,16 @@ class MemberInviteEmailApi(Resource):
|
||||
|
||||
# todo:413
|
||||
|
||||
return {'result': 'success', 'account': account}, 201
|
||||
return {
|
||||
'result': 'success',
|
||||
'account': account,
|
||||
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
str(current_user.current_tenant_id),
|
||||
invitee_email,
|
||||
token
|
||||
)
|
||||
}, 201
|
||||
|
||||
|
||||
class MemberCancelInviteApi(Resource):
|
||||
@@ -88,7 +98,7 @@ class MemberCancelInviteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
member = Account.query.get(str(member_id))
|
||||
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
|
||||
if not member:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, abort
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
|
||||
ProviderService.init_supported_provider(current_user.current_tenant)
|
||||
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
|
||||
|
||||
provider_list = [
|
||||
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
|
||||
'quota_used': p.quota_used
|
||||
} if p.provider_type == ProviderType.SYSTEM.value else {}),
|
||||
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
|
||||
ProviderName(p.provider_name))
|
||||
ProviderName(p.provider_name), only_custom=True)
|
||||
if p.provider_type == ProviderType.CUSTOM.value else None
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
|
||||
is_valid=token_is_valid)
|
||||
db.session.add(provider_model)
|
||||
|
||||
if provider_model.is_valid:
|
||||
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
|
||||
other_providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
|
||||
Provider.provider_name != provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
|
||||
|
||||
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: remove this when the provider is supported
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
|
||||
|
||||
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
|
||||
provider_model.is_valid = args['is_enabled']
|
||||
db.session.commit()
|
||||
elif not provider_model:
|
||||
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
|
||||
else:
|
||||
quota_limit = 0
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
provider,
|
||||
quota_limit,
|
||||
args['is_enabled']
|
||||
)
|
||||
else:
|
||||
abort(403)
|
||||
|
||||
|
||||
@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -56,16 +57,14 @@ class ConversationDetailApi(AppApiResource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
user = request.get_json().get('user')
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
if end_user is None and user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
return {"result": "success"}, 204
|
||||
return {"result": "success"}
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -95,3 +94,4 @@ class ConversationRenameApi(AppApiResource):
|
||||
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
|
||||
api.add_resource(ConversationApi, '/conversations')
|
||||
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>', endpoint='conversation')
|
||||
api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')
|
||||
|
||||
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
|
||||
dataset_process_rule=dataset.latest_process_rule,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
if doc_type and doc_metadata:
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import completion, app, conversation, message, site, saved_message, audio
|
||||
from . import completion, app, conversation, message, site, saved_message, audio, passport
|
||||
|
||||
@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
64
api/controllers/web/passport.py
Normal file
64
api/controllers/web/passport.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from controllers.web import api
|
||||
from flask_restful import Resource
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Unauthorized, NotFound
|
||||
from models.model import Site, EndUser, App
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
|
||||
class PassportResource(Resource):
|
||||
"""Base resource for passport."""
|
||||
def get(self):
|
||||
app_id = request.headers.get('X-App-Code')
|
||||
if app_id is None:
|
||||
raise Unauthorized('X-App-Code header is missing.')
|
||||
|
||||
# get site from db and check if it is normal
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == app_id,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
if not site:
|
||||
raise NotFound()
|
||||
# get app from db and check if it is normal and enable_site
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=generate_session_id(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
payload = {
|
||||
"iss": site.app_id,
|
||||
'sub': 'Web API Passport',
|
||||
'app_id': site.app_id,
|
||||
'end_user_id': end_user.id,
|
||||
}
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
'access_token': tk,
|
||||
}
|
||||
|
||||
api.add_resource(PassportResource, '/passport')
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
while True:
|
||||
session_id = str(uuid.uuid4())
|
||||
existing_count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
if existing_count == 0:
|
||||
return session_id
|
||||
@@ -1,110 +1,48 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from flask import request, session
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Site, EndUser
|
||||
from models.model import App, EndUser
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
def validate_token(view=None):
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
site = validate_and_get_site()
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
|
||||
if app_model.status != 'normal':
|
||||
raise NotFound()
|
||||
|
||||
if not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = create_or_update_end_user_for_session(app_model)
|
||||
app_model, end_user = decode_jwt_token()
|
||||
|
||||
return view(app_model, end_user, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_site():
|
||||
"""
|
||||
Validate and get API token.
|
||||
"""
|
||||
def decode_jwt_token():
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized('Authorization header is missing.')
|
||||
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == auth_token,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not site:
|
||||
decoded = PassportService().verify(tk)
|
||||
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
def create_or_update_end_user_for_session(app_model):
|
||||
"""
|
||||
Create or update session terminal based on session ID.
|
||||
"""
|
||||
if 'session_id' not in session:
|
||||
session['session_id'] = generate_session_id()
|
||||
|
||||
session_id = session.get('session_id')
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.type == 'browser'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=session_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
count = 1
|
||||
session_id = ''
|
||||
while count != 0:
|
||||
session_id = str(uuid.uuid4())
|
||||
count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
|
||||
return session_id
|
||||
|
||||
return app_model, end_user
|
||||
|
||||
class WebApiResource(Resource):
|
||||
method_decorators = [validate_token]
|
||||
method_decorators = [validate_jwt_token]
|
||||
|
||||
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedAnthropicCredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedLLMCredentials(BaseModel):
|
||||
openai: Optional[HostedOpenAICredential] = None
|
||||
anthropic: Optional[HostedAnthropicCredential] = None
|
||||
|
||||
|
||||
hosted_llm_credentials = HostedLLMCredentials()
|
||||
@@ -26,3 +31,6 @@ def init_app(app: Flask):
|
||||
|
||||
if app.config.get("OPENAI_API_KEY"):
|
||||
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
|
||||
|
||||
if app.config.get("ANTHROPIC_API_KEY"):
|
||||
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
|
||||
@@ -118,6 +118,7 @@ class Completion:
|
||||
prompt, stop_words = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
@@ -129,6 +130,7 @@ class Completion:
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode=mode
|
||||
)
|
||||
@@ -138,7 +140,8 @@ class Completion:
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
chain_output: Optional[str],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
@@ -151,10 +154,11 @@ class Completion:
|
||||
|
||||
if mode == 'completion':
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following CONTEXT as your learned knowledge:
|
||||
[CONTEXT]
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
|
||||
|
||||
if chain_output:
|
||||
human_inputs['context'] = chain_output
|
||||
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
|
||||
[CONTEXT]
|
||||
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt:
|
||||
human_message_prompt += pre_prompt
|
||||
|
||||
query_prompt = "\nHuman: {{query}}\nAI: "
|
||||
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
|
||||
|
||||
if memory:
|
||||
# append chat histories
|
||||
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
|
||||
inputs=human_inputs
|
||||
)
|
||||
|
||||
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
|
||||
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
|
||||
- memory.llm.max_tokens - curr_message_tokens
|
||||
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
|
||||
model_name = model['name']
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = llm_constant.max_context_token_length[model_name] \
|
||||
- max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
|
||||
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
|
||||
# if histories_param not in human_inputs:
|
||||
# human_inputs[histories_param] = '{{' + histories_param + '}}'
|
||||
|
||||
human_message_prompt += "\n\n" + histories
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
"inside <histories></histories> XML tags.\n\n<histories>"
|
||||
human_message_prompt += histories + "</histories>"
|
||||
|
||||
human_message_prompt += query_prompt
|
||||
|
||||
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
|
||||
model=app_model_config.model_dict
|
||||
)
|
||||
|
||||
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
max_tokens = llm.max_tokens
|
||||
model_name = app_model_config.model_dict.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
|
||||
|
||||
# get prompt without memory and context
|
||||
prompt, _ = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
|
||||
return rest_tokens
|
||||
|
||||
@classmethod
|
||||
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
|
||||
prompt: Union[str, List[BaseMessage]], mode: str):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
|
||||
max_tokens = final_llm.max_tokens
|
||||
model_name = model.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
|
||||
if mode == 'completion' and isinstance(final_llm, BaseLLM):
|
||||
prompt_tokens = final_llm.get_num_tokens(prompt)
|
||||
else:
|
||||
prompt_tokens = final_llm.get_messages_tokens(prompt)
|
||||
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
|
||||
|
||||
if prompt_tokens + max_tokens > model_limited_tokens:
|
||||
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
|
||||
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
model=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
|
||||
original_prompt, _ = cls.get_main_llm_prompt(
|
||||
mode="completion",
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=pre_prompt,
|
||||
query=message.query,
|
||||
inputs=message.inputs,
|
||||
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode='completion'
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from _decimal import Decimal
|
||||
|
||||
models = {
|
||||
'claude-instant-1': 'anthropic', # 100,000 tokens
|
||||
'claude-2': 'anthropic', # 100,000 tokens
|
||||
'gpt-4': 'openai', # 8,192 tokens
|
||||
'gpt-4-32k': 'openai', # 32,768 tokens
|
||||
'gpt-3.5-turbo': 'openai', # 4,096 tokens
|
||||
@@ -10,10 +12,13 @@ models = {
|
||||
'text-curie-001': 'openai', # 2,049 tokens
|
||||
'text-babbage-001': 'openai', # 2,049 tokens
|
||||
'text-ada-001': 'openai', # 2,049 tokens
|
||||
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
|
||||
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
|
||||
'whisper-1': 'openai'
|
||||
}
|
||||
|
||||
max_context_token_length = {
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
@@ -23,17 +28,21 @@ max_context_token_length = {
|
||||
'text-curie-001': 2049,
|
||||
'text-babbage-001': 2049,
|
||||
'text-ada-001': 2049,
|
||||
'text-embedding-ada-002': 8191
|
||||
'text-embedding-ada-002': 8191,
|
||||
}
|
||||
|
||||
models_by_mode = {
|
||||
'chat': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
@@ -52,6 +61,14 @@ models_by_mode = {
|
||||
model_currency = 'USD'
|
||||
|
||||
model_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': Decimal('0.00163'),
|
||||
'completion': Decimal('0.00551'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': Decimal('0.01102'),
|
||||
'completion': Decimal('0.03268'),
|
||||
},
|
||||
'gpt-4': {
|
||||
'prompt': Decimal('0.03'),
|
||||
'completion': Decimal('0.06'),
|
||||
|
||||
@@ -56,7 +56,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
def init(self):
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
|
||||
self.model_dict['provider'] = provider_name
|
||||
|
||||
override_model_configs = None
|
||||
@@ -89,7 +89,7 @@ class ConversationMessageTask:
|
||||
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
|
||||
system_instruction = system_message.content
|
||||
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
|
||||
system_instruction_tokens = llm.get_messages_tokens([system_message])
|
||||
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
|
||||
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
@@ -185,6 +185,7 @@ class ConversationMessageTask:
|
||||
if provider and provider.provider_type == ProviderType.SYSTEM.value:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.app.tenant_id,
|
||||
Provider.provider_name == provider.provider_name,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + 1})
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
|
||||
text_embeddings.extend(embedding_results)
|
||||
return text_embeddings
|
||||
|
||||
@handle_openai_exceptions
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
|
||||
@@ -23,6 +23,10 @@ class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_conversation_name(cls, tenant_id: str, query, answer):
|
||||
prompt = CONVERSATION_TITLE_PROMPT
|
||||
|
||||
if len(query) > 2000:
|
||||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
|
||||
prompt = prompt.format(query=query)
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
@@ -52,7 +56,17 @@ class LLMGenerator:
|
||||
if not message.answer:
|
||||
continue
|
||||
|
||||
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
|
||||
if len(message.query) > 2000:
|
||||
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
|
||||
else:
|
||||
query = message.query
|
||||
|
||||
if len(message.answer) > 2000:
|
||||
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
|
||||
else:
|
||||
answer = message.answer
|
||||
|
||||
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
|
||||
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
|
||||
context += message_qa_text
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class IndexBuilder:
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
|
||||
"""
|
||||
description = "Provider Token Not Init"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.description = args[0] if args else self.description
|
||||
|
||||
|
||||
class QuotaExceededError(Exception):
|
||||
"""
|
||||
|
||||
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
|
||||
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
|
||||
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from models.provider import ProviderType
|
||||
from models.provider import ProviderType, ProviderName
|
||||
|
||||
|
||||
class LLMBuilder:
|
||||
@@ -32,43 +33,43 @@ class LLMBuilder:
|
||||
|
||||
@classmethod
|
||||
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
provider = cls.get_default_provider(tenant_id)
|
||||
provider = cls.get_default_provider(tenant_id, model_name)
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
llm_cls = None
|
||||
mode = cls.get_mode_by_model(model_name)
|
||||
if mode == 'chat':
|
||||
if provider == 'openai':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableChatOpenAI
|
||||
else:
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureChatOpenAI
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
llm_cls = StreamableChatAnthropic
|
||||
elif mode == 'completion':
|
||||
if provider == 'openai':
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
llm_cls = StreamableOpenAI
|
||||
else:
|
||||
elif provider == ProviderName.AZURE_OPENAI.value:
|
||||
llm_cls = StreamableAzureOpenAI
|
||||
else:
|
||||
|
||||
if not llm_cls:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
|
||||
model_kwargs = {
|
||||
'model_name': model_name,
|
||||
'temperature': kwargs.get('temperature', 0),
|
||||
'max_tokens': kwargs.get('max_tokens', 256),
|
||||
'top_p': kwargs.get('top_p', 1),
|
||||
'frequency_penalty': kwargs.get('frequency_penalty', 0),
|
||||
'presence_penalty': kwargs.get('presence_penalty', 0),
|
||||
'callbacks': kwargs.get('callbacks', None),
|
||||
'streaming': kwargs.get('streaming', False),
|
||||
}
|
||||
|
||||
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
|
||||
model_kwargs.update(model_credentials)
|
||||
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
|
||||
|
||||
return llm_cls(
|
||||
model_name=model_name,
|
||||
temperature=kwargs.get('temperature', 0),
|
||||
max_tokens=kwargs.get('max_tokens', 256),
|
||||
**model_extras_kwargs,
|
||||
callbacks=kwargs.get('callbacks', None),
|
||||
streaming=kwargs.get('streaming', False),
|
||||
# request_timeout=None
|
||||
**model_credentials
|
||||
)
|
||||
return llm_cls(**model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
|
||||
@@ -118,14 +119,29 @@ class LLMBuilder:
|
||||
return provider_service.get_credentials(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_default_provider(cls, tenant_id: str) -> str:
|
||||
provider = BaseProvider.get_valid_provider(tenant_id)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
|
||||
provider_name = llm_constant.models[model_name]
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider_name = 'openai'
|
||||
else:
|
||||
provider_name = provider.provider_name
|
||||
if provider_name == 'openai':
|
||||
# get the default provider (openai / azure_openai) for the tenant
|
||||
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {provider_name} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider_name = 'openai'
|
||||
else:
|
||||
provider_name = provider.provider_name
|
||||
|
||||
return provider_name
|
||||
|
||||
@@ -1,23 +1,138 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import anthropic
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName, ProviderType
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
# todo
|
||||
return []
|
||||
return [
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
'name': 'claude-2',
|
||||
},
|
||||
]
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
|
||||
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
|
||||
"""
|
||||
return {
|
||||
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
return self.get_provider_api_key(model_id=model_id)
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.ANTHROPIC
|
||||
return ProviderName.ANTHROPIC
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
if obfuscated:
|
||||
if not config.get('anthropic_api_key'):
|
||||
config = {
|
||||
'anthropic_api_key': ''
|
||||
}
|
||||
|
||||
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
|
||||
return config
|
||||
|
||||
return config
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
Returns the encrypted token.
|
||||
"""
|
||||
return json.dumps({
|
||||
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
|
||||
})
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
"""
|
||||
Returns the decrypted token.
|
||||
"""
|
||||
config = json.loads(token)
|
||||
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
# check OpenAI / Azure OpenAI credential is valid
|
||||
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
|
||||
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
|
||||
|
||||
provider = None
|
||||
if openai_provider:
|
||||
provider = openai_provider
|
||||
elif azure_openai_provider:
|
||||
provider = azure_openai_provider
|
||||
|
||||
if not provider:
|
||||
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
|
||||
if quota_used >= quota_limit:
|
||||
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
|
||||
f"please configure OpenAI or Azure OpenAI provider first.")
|
||||
|
||||
try:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'anthropic_api_key' not in config:
|
||||
raise ValueError('anthropic_api_key must be provided.')
|
||||
|
||||
chat_llm = ChatAnthropic(
|
||||
model='claude-instant-1',
|
||||
anthropic_api_key=config['anthropic_api_key'],
|
||||
max_tokens_to_sample=10,
|
||||
temperature=0,
|
||||
default_request_timeout=60
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="ping"
|
||||
)
|
||||
]
|
||||
|
||||
chat_llm(messages)
|
||||
except anthropic.APIConnectionError as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
|
||||
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
|
||||
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
|
||||
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
|
||||
except Exception as ex:
|
||||
logging.exception('Anthropic config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
|
||||
|
||||
@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
|
||||
def get_provider_name(self):
|
||||
return ProviderName.AZURE_OPENAI
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
# TODO: change to dict when implemented
|
||||
return dict
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the decrypted API key for the given tenant_id and provider_name.
|
||||
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
|
||||
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
|
||||
"""
|
||||
provider = self.get_provider(prefer_custom)
|
||||
provider = self.get_provider(only_custom)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
|
||||
else:
|
||||
return self.get_decrypted_token(provider.encrypted_config)
|
||||
|
||||
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
|
||||
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
|
||||
|
||||
@classmethod
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
|
||||
Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
If both CUSTOM and System providers exist.
|
||||
"""
|
||||
query = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id
|
||||
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
|
||||
if provider_name:
|
||||
query = query.filter(Provider.provider_name == provider_name)
|
||||
|
||||
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
|
||||
if only_custom:
|
||||
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
|
||||
|
||||
custom_provider = None
|
||||
system_provider = None
|
||||
providers = query.order_by(Provider.provider_type.asc()).all()
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
return provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
return provider
|
||||
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
elif system_provider:
|
||||
return system_provider
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_hosted_credentials(self) -> str:
|
||||
if self.get_provider_name() != ProviderName.OPENAI:
|
||||
raise ProviderTokenNotInitError()
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = self.get_provider_api_key(only_custom=only_custom)
|
||||
except:
|
||||
config = ''
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ class LLMProviderService:
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
return self.provider.get_credentials(model_id)
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated)
|
||||
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
|
||||
|
||||
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
return self.provider.get_provider(prefer_custom)
|
||||
def get_provider_db_record(self) -> Optional[Provider]:
|
||||
return self.provider.get_provider()
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Optional, Union
|
||||
import openai
|
||||
from openai.error import AuthenticationError, OpenAIError
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.moderation import Moderation
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
|
||||
except Exception as ex:
|
||||
logging.exception('OpenAI config validation failed')
|
||||
raise ex
|
||||
|
||||
def get_hosted_credentials(self) -> Union[str | dict]:
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError(
|
||||
f"No valid {self.get_provider_name().value} model provider credentials found. "
|
||||
f"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import BaseMessage, LLMResult
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: The messages to count the tokens of.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the messages.
|
||||
"""
|
||||
tokens_per_message = 5
|
||||
tokens_per_request = 3
|
||||
|
||||
message_tokens = tokens_per_request
|
||||
message_strs = ''
|
||||
for message in messages:
|
||||
message_strs += message.content
|
||||
message_tokens += tokens_per_message
|
||||
|
||||
# calc once
|
||||
message_tokens += self.get_num_tokens(message_strs)
|
||||
|
||||
return message_tokens
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(messages, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
model_kwargs = {
|
||||
'top_p': params.get('top_p', 1),
|
||||
'frequency_penalty': params.get('frequency_penalty', 0),
|
||||
'presence_penalty': params.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
del params['top_p']
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
params['model_kwargs'] = model_kwargs
|
||||
|
||||
return params
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableAzureOpenAI(AzureOpenAI):
|
||||
@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
39
api/core/llm/streamable_chat_anthropic.py
Normal file
39
api/core/llm/streamable_chat_anthropic.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema import BaseMessage, LLMResult
|
||||
|
||||
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
|
||||
|
||||
|
||||
class StreamableChatAnthropic(ChatAnthropic):
|
||||
"""
|
||||
Wrapper around Anthropic's large language model.
|
||||
"""
|
||||
|
||||
@handle_anthropic_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
params['model'] = params.get('model_name')
|
||||
del params['model_name']
|
||||
|
||||
params['max_tokens_to_sample'] = params.get('max_tokens')
|
||||
del params['max_tokens']
|
||||
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
return params
|
||||
@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableChatOpenAI(ChatOpenAI):
|
||||
@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: The messages to count the tokens of.
|
||||
|
||||
Returns:
|
||||
The number of tokens in the messages.
|
||||
"""
|
||||
tokens_per_message = 5
|
||||
tokens_per_request = 3
|
||||
|
||||
message_tokens = tokens_per_request
|
||||
message_strs = ''
|
||||
for message in messages:
|
||||
message_strs += message.content
|
||||
message_tokens += tokens_per_message
|
||||
|
||||
# calc once
|
||||
message_tokens += self.get_num_tokens(message_strs)
|
||||
|
||||
return message_tokens
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(messages, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(messages, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
model_kwargs = {
|
||||
'top_p': params.get('top_p', 1),
|
||||
'frequency_penalty': params.get('frequency_penalty', 0),
|
||||
'presence_penalty': params.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
del params['top_p']
|
||||
del params['frequency_penalty']
|
||||
del params['presence_penalty']
|
||||
|
||||
params['model_kwargs'] = model_kwargs
|
||||
|
||||
return params
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
|
||||
from langchain import OpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
|
||||
|
||||
class StreamableOpenAI(OpenAI):
|
||||
@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop, callbacks, **kwargs)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop, callbacks, **kwargs)
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import openai
|
||||
|
||||
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
|
||||
from models.provider import ProviderName
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions
|
||||
from core.llm.provider.base import BaseProvider
|
||||
|
||||
|
||||
@@ -13,7 +14,7 @@ class Whisper:
|
||||
self.client = openai.Audio
|
||||
self.credentials = provider.get_credentials()
|
||||
|
||||
@handle_llm_exceptions
|
||||
@handle_openai_exceptions
|
||||
def transcribe(self, file):
|
||||
return self.client.transcribe(
|
||||
model='whisper-1',
|
||||
|
||||
27
api/core/llm/wrappers/anthropic_wrapper.py
Normal file
27
api/core/llm/wrappers/anthropic_wrapper.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
import anthropic
|
||||
|
||||
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_anthropic_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except anthropic.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to Anthropic API.")
|
||||
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
|
||||
except anthropic.RateLimitError:
|
||||
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(f"Anthropic: {e.message}")
|
||||
except anthropic.BadRequestError as e:
|
||||
raise LLMBadRequestError(f"Anthropic: {e.message}")
|
||||
except anthropic.APIStatusError as e:
|
||||
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
|
||||
|
||||
return wrapper
|
||||
@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
|
||||
LLMBadRequestError
|
||||
|
||||
|
||||
def handle_llm_exceptions(func):
|
||||
def handle_openai_exceptions(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def handle_llm_exceptions_async(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except openai.error.InvalidRequestError as e:
|
||||
logging.exception("Invalid request to OpenAI API.")
|
||||
raise LLMBadRequestError(str(e))
|
||||
except openai.error.APIConnectionError as e:
|
||||
logging.exception("Failed to connect to OpenAI API.")
|
||||
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
|
||||
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
|
||||
logging.exception("OpenAI service unavailable.")
|
||||
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
|
||||
except openai.error.RateLimitError as e:
|
||||
raise LLMRateLimitError(str(e))
|
||||
except openai.error.AuthenticationError as e:
|
||||
raise LLMAuthorizationError(str(e))
|
||||
except openai.error.OpenAIError as e:
|
||||
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
|
||||
|
||||
return wrapper
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
|
||||
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
|
||||
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
conversation: Conversation
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
|
||||
ai_prefix: str = "Assistant"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "chat_history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
return chat_messages
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
|
||||
|
||||
return chat_messages
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
|
||||
else:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=self.dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=self.dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from flask import current_app
|
||||
|
||||
from events.tenant_event import tenant_was_updated
|
||||
from models.provider import ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
|
||||
def handle(sender, **kwargs):
|
||||
tenant = sender
|
||||
if tenant.status == 'normal':
|
||||
ProviderService.create_system_provider(tenant)
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.OPENAI.value,
|
||||
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from flask import current_app
|
||||
|
||||
from events.tenant_event import tenant_was_created
|
||||
from models.provider import ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
|
||||
def handle(sender, **kwargs):
|
||||
tenant = sender
|
||||
if tenant.status == 'normal':
|
||||
ProviderService.create_system_provider(tenant)
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.OPENAI.value,
|
||||
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
|
||||
61
api/extensions/ext_mail.py
Normal file
61
api/extensions/ext_mail.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Optional
|
||||
|
||||
import resend
|
||||
from flask import Flask
|
||||
|
||||
|
||||
class Mail:
|
||||
def __init__(self):
|
||||
self._client = None
|
||||
self._default_send_from = None
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
if app.config.get('MAIL_TYPE'):
|
||||
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
|
||||
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
|
||||
|
||||
if app.config.get('MAIL_TYPE') == 'resend':
|
||||
api_key = app.config.get('RESEND_API_KEY')
|
||||
if not api_key:
|
||||
raise ValueError('RESEND_API_KEY is not set')
|
||||
|
||||
resend.api_key = api_key
|
||||
self._client = resend.Emails
|
||||
else:
|
||||
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
|
||||
|
||||
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
|
||||
if not self._client:
|
||||
raise ValueError('Mail client is not initialized')
|
||||
|
||||
if not from_ and self._default_send_from:
|
||||
from_ = self._default_send_from
|
||||
|
||||
if not from_:
|
||||
raise ValueError('mail from is not set')
|
||||
|
||||
if not to:
|
||||
raise ValueError('mail to is not set')
|
||||
|
||||
if not subject:
|
||||
raise ValueError('mail subject is not set')
|
||||
|
||||
if not html:
|
||||
raise ValueError('mail html is not set')
|
||||
|
||||
self._client.send({
|
||||
"from": from_,
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"html": html
|
||||
})
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
mail.init_app(app)
|
||||
|
||||
|
||||
mail = Mail()
|
||||
20
api/libs/passport.py
Normal file
20
api/libs/passport.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import jwt
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from flask import current_app
|
||||
class PassportService:
|
||||
def __init__(self):
|
||||
self.sk = current_app.config.get('SECRET_KEY')
|
||||
|
||||
def issue(self, payload):
|
||||
return jwt.encode(payload, self.sk, algorithm='HS256')
|
||||
|
||||
def verify(self, token):
|
||||
try:
|
||||
return jwt.decode(token, self.sk, algorithms=['HS256'])
|
||||
except jwt.exceptions.InvalidSignatureError:
|
||||
raise Unauthorized('Invalid token signature.')
|
||||
except jwt.exceptions.DecodeError:
|
||||
raise Unauthorized('Invalid token.')
|
||||
except jwt.exceptions.ExpiredSignatureError:
|
||||
raise Unauthorized('Token has expired.')
|
||||
@@ -38,6 +38,10 @@ class Account(UserMixin, db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
return self.password is not None
|
||||
|
||||
@property
|
||||
def current_tenant(self):
|
||||
return self._current_tenant
|
||||
|
||||
@@ -56,7 +56,8 @@ class App(db.Model):
|
||||
|
||||
@property
|
||||
def api_base_url(self):
|
||||
return (current_app.config['API_URL'] if current_app.config['API_URL'] else request.host_url.rstrip('/')) + '/v1'
|
||||
return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
||||
else request.host_url.rstrip('/')) + '/v1'
|
||||
|
||||
@property
|
||||
def tenant(self):
|
||||
@@ -515,7 +516,7 @@ class Site(db.Model):
|
||||
|
||||
@property
|
||||
def app_base_url(self):
|
||||
return (current_app.config['APP_URL'] if current_app.config['APP_URL'] else request.host_url.rstrip('/'))
|
||||
return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
|
||||
|
||||
|
||||
class ApiToken(db.Model):
|
||||
|
||||
@@ -10,7 +10,7 @@ flask-session2==1.3.1
|
||||
flask-cors==3.0.10
|
||||
gunicorn~=20.1.0
|
||||
gevent~=22.10.2
|
||||
langchain==0.0.209
|
||||
langchain==0.0.230
|
||||
openai~=0.27.5
|
||||
psycopg2-binary~=2.9.6
|
||||
pycryptodome==3.17
|
||||
@@ -21,7 +21,7 @@ Authlib==1.2.0
|
||||
boto3~=1.26.123
|
||||
tenacity==8.2.2
|
||||
cachetools~=5.3.0
|
||||
weaviate-client~=3.16.2
|
||||
weaviate-client~=3.21.0
|
||||
qdrant_client~=1.1.6
|
||||
mailchimp-transactional~=1.0.50
|
||||
scikit-learn==1.2.2
|
||||
@@ -32,4 +32,7 @@ redis~=4.5.4
|
||||
openpyxl==3.1.2
|
||||
chardet~=5.1.0
|
||||
docx2txt==0.8
|
||||
pypdfium2==4.16.0
|
||||
pypdfium2==4.16.0
|
||||
resend~=0.5.1
|
||||
pyjwt~=2.6.0
|
||||
anthropic~=0.3.4
|
||||
|
||||
@@ -2,13 +2,16 @@
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Optional
|
||||
|
||||
from flask import session
|
||||
from sqlalchemy import func
|
||||
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
|
||||
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
|
||||
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
|
||||
@@ -16,6 +19,7 @@ from libs.helper import get_remote_ip
|
||||
from libs.password import compare_password, hash_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import *
|
||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||
|
||||
|
||||
class AccountService:
|
||||
@@ -48,12 +52,18 @@ class AccountService:
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
"""update account password"""
|
||||
# todo: split validation and update
|
||||
if account.password and not compare_password(password, account.password, account.password_salt):
|
||||
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
||||
password_hashed = hash_password(new_password, account.password_salt)
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
return account
|
||||
|
||||
@@ -283,8 +293,6 @@ class TenantService:
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
"""Remove member from tenant"""
|
||||
# todo: check permission
|
||||
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
@@ -293,6 +301,12 @@ class TenantService:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
db.session.delete(ta)
|
||||
|
||||
account.initialized_at = None
|
||||
account.status = AccountStatus.PENDING.value
|
||||
account.password = None
|
||||
account.password_salt = None
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
@@ -332,8 +346,8 @@ class TenantService:
|
||||
|
||||
class RegisterService:
|
||||
|
||||
@staticmethod
|
||||
def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
||||
@classmethod
|
||||
def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
try:
|
||||
@@ -359,9 +373,9 @@ class RegisterService:
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
|
||||
inviter: Account = None) -> TenantAccountJoin:
|
||||
@classmethod
|
||||
def invite_new_member(cls, tenant: Tenant, email: str, role: str = 'normal',
|
||||
inviter: Account = None) -> str:
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
@@ -380,5 +394,71 @@ class RegisterService:
|
||||
if ta:
|
||||
raise AccountAlreadyInTenantError("Account already in tenant.")
|
||||
|
||||
ta = TenantService.create_tenant_member(tenant, account, role)
|
||||
return ta
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
||||
token = cls.generate_invite_token(tenant, account)
|
||||
|
||||
# send email
|
||||
send_invite_member_mail_task.delay(
|
||||
to=email,
|
||||
token=cls.generate_invite_token(tenant, account),
|
||||
inviter_name=inviter.name if inviter else 'Dify',
|
||||
workspace_id=tenant.id,
|
||||
workspace_name=tenant.name,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
email_hash = sha256(account.email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token)
|
||||
redis_client.setex(cache_key, 3600, str(account.id))
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]:
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == workspace_id,
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
account_id = cls._get_account_id_by_invite_token(workspace_id, email, token)
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
account = tenant_account[0]
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account_id != str(account.id):
|
||||
return None
|
||||
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
account_id = redis_client.get(cache_key)
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
return account_id.decode('utf-8')
|
||||
|
||||
@@ -6,6 +6,30 @@ from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
|
||||
MODEL_PROVIDERS = [
|
||||
'openai',
|
||||
'anthropic',
|
||||
]
|
||||
|
||||
MODELS_BY_APP_MODE = {
|
||||
'chat': [
|
||||
'claude-instant-1',
|
||||
'claude-2',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k',
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1',
|
||||
'claude-2',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k',
|
||||
'text-davinci-003',
|
||||
]
|
||||
}
|
||||
|
||||
class AppModelConfigService:
|
||||
@staticmethod
|
||||
@@ -125,7 +149,7 @@ class AppModelConfigService:
|
||||
if not isinstance(config["speech_to_text"]["enabled"], bool):
|
||||
raise ValueError("enabled in speech_to_text must be of boolean type")
|
||||
|
||||
provider_name = LLMBuilder.get_default_provider(account.current_tenant_id)
|
||||
provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
|
||||
|
||||
if config["speech_to_text"]["enabled"] and provider_name != 'openai':
|
||||
raise ValueError("provider not support speech to text")
|
||||
@@ -153,14 +177,14 @@ class AppModelConfigService:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
|
||||
raise ValueError("model.provider must be 'openai'")
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
|
||||
raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
|
||||
|
||||
# model.name
|
||||
if 'name' not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
|
||||
if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
# model.completion_params
|
||||
|
||||
@@ -6,7 +6,8 @@ from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServ
|
||||
from core.llm.whisper import Whisper
|
||||
from models.provider import ProviderName
|
||||
|
||||
FILE_SIZE_LIMIT = 1 * 1024 * 1024
|
||||
FILE_SIZE = 15
|
||||
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
|
||||
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']
|
||||
|
||||
class AudioService:
|
||||
@@ -23,21 +24,16 @@ class AudioService:
|
||||
file_size = len(file_content)
|
||||
|
||||
if file_size > FILE_SIZE_LIMIT:
|
||||
message = f"({file_size} > {FILE_SIZE_LIMIT})"
|
||||
message = f"Audio size larger than {FILE_SIZE} mb"
|
||||
raise AudioTooLargeServiceError(message)
|
||||
|
||||
provider_name = LLMBuilder.get_default_provider(tenant_id)
|
||||
provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
|
||||
if provider_name != ProviderName.OPENAI.value:
|
||||
raise ProviderNotSupportSpeechToTextServiceError('haha')
|
||||
raise ProviderNotSupportSpeechToTextServiceError()
|
||||
|
||||
provider_service = LLMProviderService(tenant_id, provider_name)
|
||||
|
||||
buffer = io.BytesIO(file_content)
|
||||
buffer.name = 'temp.wav'
|
||||
buffer.name = 'temp.mp3'
|
||||
|
||||
return Whisper(provider_service.provider).transcribe(buffer)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ import datetime
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, List
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -374,6 +377,12 @@ class DocumentService:
|
||||
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = 'web'):
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = document_data["data_source"]["type"]
|
||||
@@ -521,6 +530,14 @@ class DocumentService:
|
||||
)
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_documents_count():
|
||||
documents_count = Document.query.filter(Document.completed_at.isnot(None),
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.tenant_id == current_user.current_tenant_id).count()
|
||||
return documents_count
|
||||
|
||||
@staticmethod
|
||||
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
@@ -616,6 +633,12 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -1,23 +1,13 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
class NoAudioUploadedServiceError(BaseServiceError):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
class NoAudioUploadedServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AudioTooLargeServiceError(BaseServiceError):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
class AudioTooLargeServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedAudioTypeServiceError(BaseServiceError):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
class UnsupportedAudioTypeServiceError(Exception):
|
||||
pass
|
||||
|
||||
class ProviderNotSupportSpeechToTextServiceError(BaseServiceError):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text. {message}"
|
||||
code = 400
|
||||
class ProviderNotSupportSpeechToTextServiceError(Exception):
|
||||
pass
|
||||
@@ -31,7 +31,7 @@ class HitTestingService:
|
||||
|
||||
model_credentials = LLMBuilder.get_model_credentials(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
|
||||
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||
model_name='text-embedding-ada-002'
|
||||
)
|
||||
|
||||
|
||||
@@ -10,50 +10,40 @@ from models.provider import *
|
||||
class ProviderService:
|
||||
|
||||
@staticmethod
|
||||
def init_supported_provider(tenant, edition):
|
||||
def init_supported_provider(tenant):
|
||||
"""Initialize the model provider, check whether the supported provider has a record"""
|
||||
|
||||
providers = Provider.query.filter_by(tenant_id=tenant.id).all()
|
||||
need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
|
||||
|
||||
openai_provider_exists = False
|
||||
azure_openai_provider_exists = False
|
||||
|
||||
# TODO: The cloud version needs to construct the data of the SYSTEM type
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(need_init_provider_names)
|
||||
).all()
|
||||
|
||||
exists_provider_names = []
|
||||
for provider in providers:
|
||||
if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
|
||||
openai_provider_exists = True
|
||||
if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
|
||||
azure_openai_provider_exists = True
|
||||
exists_provider_names.append(provider.provider_name)
|
||||
|
||||
# Initialize the model provider, check whether the supported provider has a record
|
||||
not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
|
||||
|
||||
# Create default providers if they don't exist
|
||||
if not openai_provider_exists:
|
||||
openai_provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=ProviderName.OPENAI.value,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(openai_provider)
|
||||
if not_exists_provider_names:
|
||||
# Initialize the model provider, check whether the supported provider has a record
|
||||
for provider_name in not_exists_provider_names:
|
||||
provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(provider)
|
||||
|
||||
if not azure_openai_provider_exists:
|
||||
azure_openai_provider = Provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_name=ProviderName.AZURE_OPENAI.value,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
is_valid=False
|
||||
)
|
||||
db.session.add(azure_openai_provider)
|
||||
|
||||
if not openai_provider_exists or not azure_openai_provider_exists:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_obfuscated_api_key(tenant, provider_name: ProviderName):
|
||||
def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
|
||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||
return llm_provider_service.get_provider_configs(obfuscated=True)
|
||||
return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
|
||||
|
||||
@staticmethod
|
||||
def get_token_type(tenant, provider_name: ProviderName):
|
||||
@@ -73,7 +63,7 @@ class ProviderService:
|
||||
return llm_provider_service.get_encrypted_token(configs)
|
||||
|
||||
@staticmethod
|
||||
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
|
||||
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
|
||||
is_valid: bool = True):
|
||||
if current_app.config['EDITION'] != 'CLOUD':
|
||||
return
|
||||
@@ -90,7 +80,7 @@ class ProviderService:
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=200,
|
||||
quota_limit=quota_limit,
|
||||
encrypted_config='',
|
||||
is_valid=is_valid,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider import Provider, ProviderType, ProviderName
|
||||
|
||||
|
||||
class WorkspaceService:
|
||||
@@ -33,7 +33,7 @@ class WorkspaceService:
|
||||
if provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value:
|
||||
if provider.is_valid:
|
||||
if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
|
||||
if system_provider and not custom_provider:
|
||||
|
||||
52
api/tasks/mail_invite_member_task.py
Normal file
52
api/tasks/mail_invite_member_task.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from flask import current_app
|
||||
|
||||
from extensions.ext_mail import mail
|
||||
|
||||
|
||||
@shared_task
|
||||
def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspace_id: str, workspace_name: str):
|
||||
"""
|
||||
Async Send invite member mail
|
||||
:param to
|
||||
:param token
|
||||
:param inviter_name
|
||||
:param workspace_id
|
||||
:param workspace_name
|
||||
|
||||
Usage: send_invite_member_mail_task.delay(to, token, inviter_name, workspace_id, workspace_name)
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name),
|
||||
fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
mail.send(
|
||||
to=to,
|
||||
subject="{} invited you to join {}".format(inviter_name, workspace_name),
|
||||
html="""<p>Hi there,</p>
|
||||
<p>{inviter_name} invited you to join {workspace_name}.</p>
|
||||
<p>Click <a href="{url}">here</a> to join.</p>
|
||||
<p>Thanks,</p>
|
||||
<p>Dify Team</p>""".format(inviter_name=inviter_name, workspace_name=workspace_name,
|
||||
url='{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
workspace_id,
|
||||
to,
|
||||
token)
|
||||
)
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Send invite member mail to {} failed".format(to))
|
||||
@@ -2,7 +2,7 @@ version: '3.1'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.3.7
|
||||
image: langgenius/dify-api:0.3.9
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@@ -11,18 +11,26 @@ services:
|
||||
LOG_LEVEL: INFO
|
||||
# A secret key that is used for securely signing the session cookie and encrypting sensitive information on the database. You can generate a strong key using `openssl rand -base64 42`.
|
||||
SECRET_KEY: sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U
|
||||
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
|
||||
# The base URL of console application web frontend, refers to the Console base URL of WEB service if console domain is
|
||||
# different from api or web app domain.
|
||||
# example: http://cloud.dify.ai
|
||||
CONSOLE_URL: ''
|
||||
CONSOLE_WEB_URL: ''
|
||||
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
|
||||
# different from api or web app domain.
|
||||
# example: http://cloud.dify.ai
|
||||
CONSOLE_API_URL: ''
|
||||
# The URL for Service API endpoints,refers to the base URL of the current API service if api domain is
|
||||
# different from console domain.
|
||||
# example: http://api.dify.ai
|
||||
API_URL: ''
|
||||
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
SERVICE_API_URL: ''
|
||||
# The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
# console or api domain.
|
||||
# example: http://udify.app
|
||||
APP_URL: ''
|
||||
APP_API_URL: ''
|
||||
# The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
# console or api domain.
|
||||
# example: http://udify.app
|
||||
APP_WEB_URL: ''
|
||||
# When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
|
||||
MIGRATION_ENABLED: 'true'
|
||||
# The configurations of postgres database connection.
|
||||
@@ -93,6 +101,12 @@ services:
|
||||
QDRANT_URL: 'https://your-qdrant-cluster-url.qdrant.tech/'
|
||||
# The Qdrant API key.
|
||||
QDRANT_API_KEY: 'ak-difyai'
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE: ''
|
||||
# default send from email address, if not specified
|
||||
MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
|
||||
# the api-key for resend (https://resend.com)
|
||||
RESEND_API_KEY: ''
|
||||
# The DSN for Sentry error reporting. If not set, Sentry error reporting will be disabled.
|
||||
SENTRY_DSN: ''
|
||||
# The sample rate for Sentry events. Default: `1.0`
|
||||
@@ -110,7 +124,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.3.7
|
||||
image: langgenius/dify-api:0.3.9
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'worker' starts the Celery worker for processing the queue.
|
||||
@@ -146,6 +160,12 @@ services:
|
||||
VECTOR_STORE: weaviate
|
||||
WEAVIATE_ENDPOINT: http://weaviate:8080
|
||||
WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE: ''
|
||||
# default send from email address, if not specified
|
||||
MAIL_DEFAULT_SEND_FROM: 'YOUR EMAIL FROM (eg: no-reply <no-reply@dify.ai>)'
|
||||
# the api-key for resend (https://resend.com)
|
||||
RESEND_API_KEY: ''
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
@@ -156,18 +176,18 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.3.7
|
||||
image: langgenius/dify-web:0.3.9
|
||||
restart: always
|
||||
environment:
|
||||
EDITION: SELF_HOSTED
|
||||
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
|
||||
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
|
||||
# different from api or web app domain.
|
||||
# example: http://cloud.dify.ai
|
||||
CONSOLE_URL: ''
|
||||
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
CONSOLE_API_URL: ''
|
||||
# The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
|
||||
# console or api domain.
|
||||
# example: http://udify.app
|
||||
APP_URL: ''
|
||||
APP_API_URL: ''
|
||||
# The DSN for Sentry error reporting. If not set, Sentry error reporting will be disabled.
|
||||
SENTRY_DSN: ''
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ LABEL maintainer="takatost@gmail.com"
|
||||
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
ENV CONSOLE_URL http://127.0.0.1:5001
|
||||
ENV APP_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_API_URL http://127.0.0.1:5001
|
||||
ENV APP_API_URL http://127.0.0.1:5001
|
||||
|
||||
EXPOSE 3000
|
||||
|
||||
|
||||
13
web/app/(shareLayout)/chatbot/[token]/page.tsx
Normal file
13
web/app/(shareLayout)/chatbot/[token]/page.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
|
||||
import type { IMainProps } from '@/app/components/share/chat'
|
||||
import Main from '@/app/components/share/chatbot'
|
||||
|
||||
const Chatbot: FC<IMainProps> = () => {
|
||||
return (
|
||||
<Main />
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Chatbot)
|
||||
233
web/app/activate/activateForm.tsx
Normal file
233
web/app/activate/activateForm.tsx
Normal file
@@ -0,0 +1,233 @@
|
||||
'use client'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import useSWR from 'swr'
|
||||
import { useSearchParams } from 'next/navigation'
|
||||
import cn from 'classnames'
|
||||
import Link from 'next/link'
|
||||
import { CheckCircleIcon } from '@heroicons/react/24/solid'
|
||||
import style from './style.module.css'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
import { SimpleSelect } from '@/app/components/base/select'
|
||||
import { timezones } from '@/utils/timezone'
|
||||
import { languageMaps, languages } from '@/utils/language'
|
||||
import { activateMember, invitationCheck } from '@/service/common'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
|
||||
const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
|
||||
|
||||
const ActivateForm = () => {
|
||||
const { t } = useTranslation()
|
||||
const searchParams = useSearchParams()
|
||||
const workspaceID = searchParams.get('workspace_id')
|
||||
const email = searchParams.get('email')
|
||||
const token = searchParams.get('token')
|
||||
|
||||
const checkParams = {
|
||||
url: '/activate/check',
|
||||
params: {
|
||||
workspace_id: workspaceID,
|
||||
email,
|
||||
token,
|
||||
},
|
||||
}
|
||||
const { data: checkRes, mutate: recheck } = useSWR(checkParams, invitationCheck, {
|
||||
revalidateOnFocus: false,
|
||||
})
|
||||
|
||||
const [name, setName] = useState('')
|
||||
const [password, setPassword] = useState('')
|
||||
const [timezone, setTimezone] = useState('Asia/Shanghai')
|
||||
const [language, setLanguage] = useState('en-US')
|
||||
const [showSuccess, setShowSuccess] = useState(false)
|
||||
|
||||
const showErrorMessage = (message: string) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message,
|
||||
})
|
||||
}
|
||||
const valid = () => {
|
||||
if (!name.trim()) {
|
||||
showErrorMessage(t('login.error.nameEmpty'))
|
||||
return false
|
||||
}
|
||||
if (!password.trim()) {
|
||||
showErrorMessage(t('login.error.passwordEmpty'))
|
||||
return false
|
||||
}
|
||||
if (!validPassword.test(password))
|
||||
showErrorMessage(t('login.error.passwordInvalid'))
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
const handleActivate = async () => {
|
||||
if (!valid())
|
||||
return
|
||||
try {
|
||||
await activateMember({
|
||||
url: '/activate',
|
||||
body: {
|
||||
workspace_id: workspaceID,
|
||||
email,
|
||||
token,
|
||||
name,
|
||||
password,
|
||||
interface_language: language,
|
||||
timezone,
|
||||
},
|
||||
})
|
||||
setShowSuccess(true)
|
||||
}
|
||||
catch {
|
||||
recheck()
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={
|
||||
cn(
|
||||
'flex flex-col items-center w-full grow items-center justify-center',
|
||||
'px-6',
|
||||
'md:px-[108px]',
|
||||
)
|
||||
}>
|
||||
{!checkRes && <Loading/>}
|
||||
{checkRes && !checkRes.is_valid && (
|
||||
<div className="flex flex-col md:w-[400px]">
|
||||
<div className="w-full mx-auto">
|
||||
<div className="mb-3 flex justify-center items-center w-20 h-20 p-5 rounded-[20px] border border-gray-100 shadow-lg text-[40px] font-bold">🤷♂️</div>
|
||||
<h2 className="text-[32px] font-bold text-gray-900">{t('login.invalid')}</h2>
|
||||
</div>
|
||||
<div className="w-full mx-auto mt-6">
|
||||
<Button type='primary' className='w-full !fone-medium !text-sm'>
|
||||
<a href="https://dify.ai">{t('login.explore')}</a>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{checkRes && checkRes.is_valid && !showSuccess && (
|
||||
<div className='flex flex-col md:w-[400px]'>
|
||||
<div className="w-full mx-auto">
|
||||
<div className={`mb-3 flex justify-center items-center w-20 h-20 p-5 rounded-[20px] border border-gray-100 shadow-lg text-[40px] font-bold ${style.logo}`}>
|
||||
</div>
|
||||
<h2 className="text-[32px] font-bold text-gray-900">
|
||||
{`${t('login.join')} ${checkRes.workspace_name}`}
|
||||
</h2>
|
||||
<p className='mt-1 text-sm text-gray-600 '>
|
||||
{`${t('login.joinTipStart')} ${checkRes.workspace_name} ${t('login.joinTipEnd')}`}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="w-full mx-auto mt-6">
|
||||
<div className="bg-white">
|
||||
{/* username */}
|
||||
<div className='mb-5'>
|
||||
<label htmlFor="name" className="my-2 flex items-center justify-between text-sm font-medium text-gray-900">
|
||||
{t('login.name')}
|
||||
</label>
|
||||
<div className="mt-1 relative rounded-md shadow-sm">
|
||||
<input
|
||||
id="name"
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={e => setName(e.target.value)}
|
||||
placeholder={t('login.namePlaceholder') || ''}
|
||||
className={'appearance-none block w-full rounded-lg pl-[14px] px-3 py-2 border border-gray-200 hover:border-gray-300 hover:shadow-sm focus:outline-none focus:ring-primary-500 focus:border-primary-500 placeholder-gray-400 caret-primary-600 sm:text-sm pr-10'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/* password */}
|
||||
<div className='mb-5'>
|
||||
<label htmlFor="password" className="my-2 flex items-center justify-between text-sm font-medium text-gray-900">
|
||||
{t('login.password')}
|
||||
</label>
|
||||
<div className="mt-1 relative rounded-md shadow-sm">
|
||||
<input
|
||||
id="password"
|
||||
type='password'
|
||||
value={password}
|
||||
onChange={e => setPassword(e.target.value)}
|
||||
placeholder={t('login.passwordPlaceholder') || ''}
|
||||
className={'appearance-none block w-full rounded-lg pl-[14px] px-3 py-2 border border-gray-200 hover:border-gray-300 hover:shadow-sm focus:outline-none focus:ring-primary-500 focus:border-primary-500 placeholder-gray-400 caret-primary-600 sm:text-sm pr-10'}
|
||||
/>
|
||||
</div>
|
||||
<div className='mt-1 text-xs text-gray-500'>{t('login.error.passwordInvalid')}</div>
|
||||
</div>
|
||||
{/* language */}
|
||||
<div className='mb-5'>
|
||||
<label htmlFor="name" className="my-2 flex items-center justify-between text-sm font-medium text-gray-900">
|
||||
{t('login.interfaceLanguage')}
|
||||
</label>
|
||||
<div className="relative mt-1 rounded-md shadow-sm">
|
||||
<SimpleSelect
|
||||
defaultValue={languageMaps.en}
|
||||
items={languages}
|
||||
onSelect={(item) => {
|
||||
setLanguage(item.value as string)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/* timezone */}
|
||||
<div className='mb-4'>
|
||||
<label htmlFor="timezone" className="block text-sm font-medium text-gray-700">
|
||||
{t('login.timezone')}
|
||||
</label>
|
||||
<div className="relative mt-1 rounded-md shadow-sm">
|
||||
<SimpleSelect
|
||||
defaultValue={timezone}
|
||||
items={timezones}
|
||||
onSelect={(item) => {
|
||||
setTimezone(item.value as string)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<Button
|
||||
type='primary'
|
||||
className='w-full !fone-medium !text-sm'
|
||||
onClick={handleActivate}
|
||||
>
|
||||
{`${t('login.join')} ${checkRes.workspace_name}`}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="block w-hull mt-2 text-xs text-gray-600">
|
||||
{t('login.license.tip')}
|
||||
|
||||
<Link
|
||||
className='text-primary-600'
|
||||
target={'_blank'}
|
||||
href='https://docs.dify.ai/community/open-source'
|
||||
>{t('login.license.link')}</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{checkRes && checkRes.is_valid && showSuccess && (
|
||||
<div className="flex flex-col md:w-[400px]">
|
||||
<div className="w-full mx-auto">
|
||||
<div className="mb-3 flex justify-center items-center w-20 h-20 p-5 rounded-[20px] border border-gray-100 shadow-lg text-[40px] font-bold">
|
||||
<CheckCircleIcon className='w-10 h-10 text-[#039855]' />
|
||||
</div>
|
||||
<h2 className="text-[32px] font-bold text-gray-900">
|
||||
{`${t('login.activatedTipStart')} ${checkRes.workspace_name} ${t('login.activatedTipEnd')}`}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="w-full mx-auto mt-6">
|
||||
<Button type='primary' className='w-full !fone-medium !text-sm'>
|
||||
<a href="/signin">{t('login.activated')}</a>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ActivateForm
|
||||
32
web/app/activate/page.tsx
Normal file
32
web/app/activate/page.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import React from 'react'
|
||||
import cn from 'classnames'
|
||||
import Header from '../signin/_header'
|
||||
import style from '../signin/page.module.css'
|
||||
import ActivateForm from './activateForm'
|
||||
|
||||
const Activate = () => {
|
||||
return (
|
||||
<div className={cn(
|
||||
style.background,
|
||||
'flex w-full min-h-screen',
|
||||
'sm:p-4 lg:p-8',
|
||||
'gap-x-20',
|
||||
'justify-center lg:justify-start',
|
||||
)}>
|
||||
<div className={
|
||||
cn(
|
||||
'flex w-full flex-col bg-white shadow rounded-2xl shrink-0',
|
||||
'space-between',
|
||||
)
|
||||
}>
|
||||
<Header />
|
||||
<ActivateForm />
|
||||
<div className='px-8 py-6 text-sm font-normal text-gray-500'>
|
||||
© {new Date().getFullYear()} Dify, Inc. All rights reserved.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Activate
|
||||
4
web/app/activate/style.module.css
Normal file
4
web/app/activate/style.module.css
Normal file
@@ -0,0 +1,4 @@
|
||||
.logo {
|
||||
background: #fff center no-repeat url(./team-28x28.png);
|
||||
background-size: 56px;
|
||||
}
|
||||
BIN
web/app/activate/team-28x28.png
Normal file
BIN
web/app/activate/team-28x28.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.2 KiB |
@@ -65,6 +65,7 @@ export type IChatProps = {
|
||||
isShowSuggestion?: boolean
|
||||
suggestionList?: string[]
|
||||
isShowSpeechToText?: boolean
|
||||
answerIconClassName?: string
|
||||
}
|
||||
|
||||
export type MessageMore = {
|
||||
@@ -174,10 +175,11 @@ type IAnswerProps = {
|
||||
onSubmitAnnotation?: SubmitAnnotationFunc
|
||||
displayScene: DisplayScene
|
||||
isResponsing?: boolean
|
||||
answerIconClassName?: string
|
||||
}
|
||||
|
||||
// The component needs to maintain its own state to control whether to display input component
|
||||
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing }) => {
|
||||
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing, answerIconClassName }) => {
|
||||
const { id, content, more, feedback, adminFeedback, annotation: initAnnotation } = item
|
||||
const [showEdit, setShowEdit] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
@@ -292,7 +294,7 @@ const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedba
|
||||
return (
|
||||
<div key={id}>
|
||||
<div className='flex items-start'>
|
||||
<div className={`${s.answerIcon} w-10 h-10 shrink-0`}>
|
||||
<div className={`${s.answerIcon} ${answerIconClassName} w-10 h-10 shrink-0`}>
|
||||
{isResponsing
|
||||
&& <div className={s.typeingIcon}>
|
||||
<LoadingAnim type='avatar' />
|
||||
@@ -428,6 +430,7 @@ const Chat: FC<IChatProps> = ({
|
||||
isShowSuggestion,
|
||||
suggestionList,
|
||||
isShowSpeechToText,
|
||||
answerIconClassName,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
@@ -473,7 +476,7 @@ const Chat: FC<IChatProps> = ({
|
||||
}
|
||||
}
|
||||
|
||||
const haneleKeyDown = (e: any) => {
|
||||
const handleKeyDown = (e: any) => {
|
||||
isUseInputMethod.current = e.nativeEvent.isComposing
|
||||
if (e.code === 'Enter' && !e.shiftKey) {
|
||||
setQuery(query.replace(/\n$/, ''))
|
||||
@@ -520,6 +523,7 @@ const Chat: FC<IChatProps> = ({
|
||||
onSubmitAnnotation={onSubmitAnnotation}
|
||||
displayScene={displayScene ?? 'web'}
|
||||
isResponsing={isResponsing && isLast}
|
||||
answerIconClassName={answerIconClassName}
|
||||
/>
|
||||
}
|
||||
return <Question key={item.id} id={item.id} content={item.content} more={item.more} useCurrentUserAvatar={useCurrentUserAvatar} />
|
||||
@@ -573,7 +577,7 @@ const Chat: FC<IChatProps> = ({
|
||||
value={query}
|
||||
onChange={handleContentChange}
|
||||
onKeyUp={handleKeyUp}
|
||||
onKeyDown={haneleKeyDown}
|
||||
onKeyDown={handleKeyDown}
|
||||
minHeight={48}
|
||||
autoFocus
|
||||
controlFocus={controlFocus}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -372,7 +372,7 @@ const Debug: FC<IDebug> = ({
|
||||
{/* Chat */}
|
||||
{mode === AppType.chat && (
|
||||
<div className="mt-[34px] h-full flex flex-col">
|
||||
<div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[66px]'), 'relative mt-1.5 grow h-[200px] overflow-hidden')}>
|
||||
<div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[76px]'), 'relative mt-1.5 grow h-[200px] overflow-hidden')}>
|
||||
<div className="h-full overflow-y-auto overflow-x-hidden" ref={chatListDomRef}>
|
||||
<Chat
|
||||
chatList={chatList}
|
||||
|
||||
@@ -16,6 +16,7 @@ import ConfigModel from '@/app/components/app/configuration/config-model'
|
||||
import Config from '@/app/components/app/configuration/config'
|
||||
import Debug from '@/app/components/app/configuration/debug'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { ProviderType } from '@/types/app'
|
||||
import type { AppDetailResponse } from '@/models/app'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { fetchTenantInfo } from '@/service/common'
|
||||
@@ -67,7 +68,7 @@ const Configuration: FC = () => {
|
||||
frequency_penalty: 1, // -2-2
|
||||
})
|
||||
const [modelConfig, doSetModelConfig] = useState<ModelConfig>({
|
||||
provider: 'openai',
|
||||
provider: ProviderType.openai,
|
||||
model_id: 'gpt-3.5-turbo',
|
||||
configs: {
|
||||
prompt_template: '',
|
||||
@@ -84,8 +85,9 @@ const Configuration: FC = () => {
|
||||
doSetModelConfig(newModelConfig)
|
||||
}
|
||||
|
||||
const setModelId = (modelId: string) => {
|
||||
const setModelId = (modelId: string, provider: ProviderType) => {
|
||||
const newModelConfig = produce(modelConfig, (draft: any) => {
|
||||
draft.provider = provider
|
||||
draft.model_id = modelId
|
||||
})
|
||||
setModelConfig(newModelConfig)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useState } from 'react'
|
||||
import {
|
||||
Cog8ToothIcon,
|
||||
@@ -11,6 +12,7 @@ import { usePathname, useRouter } from 'next/navigation'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import SettingsModal from './settings'
|
||||
import ShareLink from './share-link'
|
||||
import EmbeddedModal from './embedded'
|
||||
import CustomizeModal from './customize'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import AppBasic, { randomString } from '@/app/components/app-sidebar/basic'
|
||||
@@ -18,6 +20,8 @@ import Button from '@/app/components/base/button'
|
||||
import Tag from '@/app/components/base/tag'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import type { AppDetailResponse } from '@/models/app'
|
||||
import './style.css'
|
||||
import { AppType } from '@/types/app'
|
||||
|
||||
export type IAppCardProps = {
|
||||
className?: string
|
||||
@@ -29,6 +33,10 @@ export type IAppCardProps = {
|
||||
onGenerateCode?: () => Promise<any>
|
||||
}
|
||||
|
||||
const EmbedIcon: FC<{ className?: string }> = ({ className = '' }) => {
|
||||
return <div className={`codeBrowserIcon ${className}`}></div>
|
||||
}
|
||||
|
||||
function AppCard({
|
||||
appInfo,
|
||||
cardType = 'app',
|
||||
@@ -42,6 +50,7 @@ function AppCard({
|
||||
const pathname = usePathname()
|
||||
const [showSettingsModal, setShowSettingsModal] = useState(false)
|
||||
const [showShareModal, setShowShareModal] = useState(false)
|
||||
const [showEmbedded, setShowEmbedded] = useState(false)
|
||||
const [showCustomizeModal, setShowCustomizeModal] = useState(false)
|
||||
const { t } = useTranslation()
|
||||
|
||||
@@ -49,8 +58,9 @@ function AppCard({
|
||||
webapp: [
|
||||
{ opName: t('appOverview.overview.appInfo.preview'), opIcon: RocketLaunchIcon },
|
||||
{ opName: t('appOverview.overview.appInfo.share.entry'), opIcon: ShareIcon },
|
||||
appInfo.mode === AppType.chat ? { opName: t('appOverview.overview.appInfo.embedded.entry'), opIcon: EmbedIcon } : false,
|
||||
{ opName: t('appOverview.overview.appInfo.settings.entry'), opIcon: Cog8ToothIcon },
|
||||
],
|
||||
].filter(item => !!item),
|
||||
api: [{ opName: t('appOverview.overview.apiInfo.doc'), opIcon: DocumentTextIcon }],
|
||||
app: [],
|
||||
}
|
||||
@@ -80,6 +90,10 @@ function AppCard({
|
||||
return () => {
|
||||
setShowSettingsModal(true)
|
||||
}
|
||||
case t('appOverview.overview.appInfo.embedded.entry'):
|
||||
return () => {
|
||||
setShowEmbedded(true)
|
||||
}
|
||||
default:
|
||||
// jump to page develop
|
||||
return () => {
|
||||
@@ -139,20 +153,20 @@ function AppCard({
|
||||
key={op.opName}
|
||||
onClick={genClickFuncByName(op.opName)}
|
||||
disabled={
|
||||
[t('appOverview.overview.appInfo.preview'), t('appOverview.overview.appInfo.share.entry')].includes(op.opName) && !runningStatus
|
||||
[t('appOverview.overview.appInfo.preview'), t('appOverview.overview.appInfo.share.entry'), t('appOverview.overview.appInfo.embedded.entry')].includes(op.opName) && !runningStatus
|
||||
}
|
||||
>
|
||||
<Tooltip
|
||||
content={t('appOverview.overview.appInfo.preUseReminder') ?? ''}
|
||||
selector={`op-btn-${randomString(16)}`}
|
||||
className={
|
||||
([t('appOverview.overview.appInfo.preview'), t('appOverview.overview.appInfo.share.entry')].includes(op.opName) && !runningStatus)
|
||||
([t('appOverview.overview.appInfo.preview'), t('appOverview.overview.appInfo.share.entry'), t('appOverview.overview.appInfo.embedded.entry')].includes(op.opName) && !runningStatus)
|
||||
? 'mt-[-8px]'
|
||||
: '!hidden'
|
||||
}
|
||||
>
|
||||
<div className="flex flex-row items-center">
|
||||
<op.opIcon className="h-4 w-4 mr-1.5" />
|
||||
<op.opIcon className="h-4 w-4 mr-1.5 stroke-[1.8px]" />
|
||||
<span className="text-xs">{op.opName}</span>
|
||||
</div>
|
||||
</Tooltip>
|
||||
@@ -193,6 +207,12 @@ function AppCard({
|
||||
onClose={() => setShowSettingsModal(false)}
|
||||
onSave={onSaveSiteConfig}
|
||||
/>
|
||||
<EmbeddedModal
|
||||
isShow={showEmbedded}
|
||||
onClose={() => setShowEmbedded(false)}
|
||||
appBaseUrl={app_base_url}
|
||||
accessToken={access_token}
|
||||
/>
|
||||
<CustomizeModal
|
||||
isShow={showCustomizeModal}
|
||||
linkUrl=""
|
||||
|
||||
3
web/app/components/app/overview/assets/code-browser.svg
Normal file
3
web/app/components/app/overview/assets/code-browser.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M14.6667 6H1.33337M9.33337 11.6667L11 10L9.33337 8.33333M6.66671 8.33333L5.00004 10L6.66671 11.6667M1.33337 5.2L1.33337 10.8C1.33337 11.9201 1.33337 12.4802 1.55136 12.908C1.74311 13.2843 2.04907 13.5903 2.42539 13.782C2.85322 14 3.41327 14 4.53337 14H11.4667C12.5868 14 13.1469 14 13.5747 13.782C13.951 13.5903 14.257 13.2843 14.4487 12.908C14.6667 12.4802 14.6667 11.9201 14.6667 10.8V5.2C14.6667 4.0799 14.6667 3.51984 14.4487 3.09202C14.257 2.7157 13.951 2.40973 13.5747 2.21799C13.1469 2 12.5868 2 11.4667 2L4.53337 2C3.41327 2 2.85322 2 2.42539 2.21799C2.04907 2.40973 1.74311 2.71569 1.55136 3.09202C1.33337 3.51984 1.33337 4.0799 1.33337 5.2Z" stroke="#344054" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 850 B |
102
web/app/components/app/overview/assets/iframe-option.svg
Normal file
102
web/app/components/app/overview/assets/iframe-option.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 43 KiB |
160
web/app/components/app/overview/assets/scripts-option.svg
Normal file
160
web/app/components/app/overview/assets/scripts-option.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 47 KiB |
111
web/app/components/app/overview/embedded/index.tsx
Normal file
111
web/app/components/app/overview/embedded/index.tsx
Normal file
@@ -0,0 +1,111 @@
|
||||
import React, { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import cn from 'classnames'
|
||||
import style from './style.module.css'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import useCopyToClipboard from '@/hooks/use-copy-to-clipboard'
|
||||
import copyStyle from '@/app/components/app/chat/copy-btn/style.module.css'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
|
||||
// const isDevelopment = process.env.NODE_ENV === 'development'
|
||||
|
||||
type Props = {
|
||||
isShow: boolean
|
||||
onClose: () => void
|
||||
accessToken: string
|
||||
appBaseUrl: string
|
||||
}
|
||||
|
||||
const OPTION_MAP = {
|
||||
iframe: {
|
||||
getContent: (url: string, token: string) =>
|
||||
`<iframe
|
||||
src="${url}/chatbot/${token}"
|
||||
style="width: 100%; height: 100%; min-height: 700px"
|
||||
frameborder="0"
|
||||
allow="microphone">
|
||||
</iframe>`,
|
||||
},
|
||||
scripts: {
|
||||
getContent: (url: string, token: string, isTestEnv?: boolean) =>
|
||||
`<script>
|
||||
window.difyChatbotConfig = { token: '${token}'${isTestEnv ? ', isDev: true' : ''} }
|
||||
</script>
|
||||
<script
|
||||
src="${url}/embed.min.js"
|
||||
id="${token}"
|
||||
defer>
|
||||
</script>`,
|
||||
},
|
||||
}
|
||||
const prefixEmbedded = 'appOverview.overview.appInfo.embedded'
|
||||
|
||||
type Option = keyof typeof OPTION_MAP
|
||||
|
||||
const Embedded = ({ isShow, onClose, appBaseUrl, accessToken }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const [option, setOption] = useState<Option>('iframe')
|
||||
const [isCopied, setIsCopied] = useState({ iframe: false, scripts: false })
|
||||
const [_, copy] = useCopyToClipboard()
|
||||
|
||||
const { langeniusVersionInfo } = useAppContext()
|
||||
const isTestEnv = langeniusVersionInfo.current_env === 'TESTING' || langeniusVersionInfo.current_env === 'DEVELOPMENT'
|
||||
const onClickCopy = () => {
|
||||
copy(OPTION_MAP[option].getContent(appBaseUrl, accessToken, isTestEnv))
|
||||
setIsCopied({ ...isCopied, [option]: true })
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t(`${prefixEmbedded}.title`)}
|
||||
isShow={isShow}
|
||||
onClose={onClose}
|
||||
className="!max-w-2xl w-[640px]"
|
||||
closable={true}
|
||||
>
|
||||
<div className="mb-4 mt-8 text-gray-900 text-[14px] font-medium leading-tight">
|
||||
{t(`${prefixEmbedded}.explanation`)}
|
||||
</div>
|
||||
<div className="flex gap-4 items-center">
|
||||
{Object.keys(OPTION_MAP).map((v, index) => {
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className={cn(
|
||||
style.option,
|
||||
style[`${v}Icon`],
|
||||
option === v && style.active,
|
||||
)}
|
||||
onClick={() => setOption(v as Option)}
|
||||
></div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
<div className="mt-6 w-full bg-gray-100 rounded-lg flex-col justify-start items-start inline-flex">
|
||||
<div className="self-stretch pl-3 pr-1 py-1 bg-gray-50 rounded-tl-lg rounded-tr-lg border border-black border-opacity-5 justify-start items-center gap-2 inline-flex">
|
||||
<div className="grow shrink basis-0 text-slate-700 text-[13px] font-medium leading-none">
|
||||
{t(`${prefixEmbedded}.${option}`)}
|
||||
</div>
|
||||
<div className="p-2 rounded-lg justify-center items-center gap-1 flex">
|
||||
<Tooltip
|
||||
selector={'code-copy-feedback'}
|
||||
content={(isCopied[option] ? t(`${prefixEmbedded}.copied`) : t(`${prefixEmbedded}.copy`)) || ''}
|
||||
>
|
||||
<div className="w-8 h-8 cursor-pointer hover:bg-gray-100 rounded-lg">
|
||||
<div onClick={onClickCopy} className={`w-full h-full ${copyStyle.copyIcon} ${isCopied[option] ? copyStyle.copied : ''}`}></div>
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
<div className="self-stretch p-3 justify-start items-start gap-2 inline-flex">
|
||||
<div className="grow shrink basis-0 text-slate-700 text-[13px] leading-tight font-mono">
|
||||
<pre className='select-text'>{OPTION_MAP[option].getContent(appBaseUrl, accessToken, isTestEnv)}</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
|
||||
export default Embedded
|
||||
14
web/app/components/app/overview/embedded/style.module.css
Normal file
14
web/app/components/app/overview/embedded/style.module.css
Normal file
@@ -0,0 +1,14 @@
|
||||
.option {
|
||||
width: 188px;
|
||||
height: 128px;
|
||||
@apply box-border cursor-pointer bg-auto bg-no-repeat bg-center rounded-md;
|
||||
}
|
||||
.active {
|
||||
@apply border-[1.5px] border-[#2970FF];
|
||||
}
|
||||
.iframeIcon {
|
||||
background-image: url(../assets/iframe-option.svg);
|
||||
}
|
||||
.scriptsIcon {
|
||||
background-image: url(../assets/scripts-option.svg);
|
||||
}
|
||||
@@ -11,3 +11,8 @@
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.codeBrowserIcon {
|
||||
@apply w-4 h-4 bg-center bg-no-repeat;
|
||||
background-image: url(./assets/code-browser.svg);
|
||||
}
|
||||
|
||||
@@ -128,6 +128,9 @@ const GenerationItem: FC<IGenerationItemProps> = ({
|
||||
startQuerying()
|
||||
const res: any = await fetchMoreLikeThis(messageId as string, isInstalledApp, installedAppId)
|
||||
setCompletionRes(res.answer)
|
||||
setChildFeedback({
|
||||
rating: null,
|
||||
})
|
||||
setChildMessageId(res.id)
|
||||
stopQuerying()
|
||||
}
|
||||
@@ -152,6 +155,12 @@ const GenerationItem: FC<IGenerationItemProps> = ({
|
||||
}
|
||||
}, [controlClearMoreLikeThis])
|
||||
|
||||
// regeneration clear child
|
||||
useEffect(() => {
|
||||
if (isLoading)
|
||||
setChildMessageId(null)
|
||||
}, [isLoading])
|
||||
|
||||
return (
|
||||
<div className={cn(className, isTop ? 'rounded-xl border border-gray-200 bg-white' : 'rounded-br-xl !mt-0')}
|
||||
style={isTop
|
||||
@@ -175,7 +184,11 @@ const GenerationItem: FC<IGenerationItemProps> = ({
|
||||
{taskId}
|
||||
</div>)
|
||||
}
|
||||
<Markdown content={content} />
|
||||
<div className='flex'>
|
||||
<div className='grow w-0'>
|
||||
<Markdown content={content} />
|
||||
</div>
|
||||
</div>
|
||||
{messageId && (
|
||||
<div className='flex items-center justify-between mt-3'>
|
||||
<div className='flex items-center'>
|
||||
|
||||
@@ -19,6 +19,7 @@ const AutoHeightTextarea = forwardRef(
|
||||
{ value, onChange, placeholder, className, minHeight = 36, maxHeight = 96, autoFocus, controlFocus, onKeyDown, onKeyUp }: IProps,
|
||||
outerRef: any,
|
||||
) => {
|
||||
// eslint-disable-next-line react-hooks/rules-of-hooks
|
||||
const ref = outerRef || useRef<HTMLTextAreaElement>(null)
|
||||
|
||||
const doFocus = () => {
|
||||
@@ -54,13 +55,20 @@ const AutoHeightTextarea = forwardRef(
|
||||
|
||||
return (
|
||||
<div className='relative'>
|
||||
<div className={cn(className, 'invisible whitespace-pre-wrap break-all overflow-y-auto')} style={{ minHeight, maxHeight }}>
|
||||
<div className={cn(className, 'invisible whitespace-pre-wrap break-all overflow-y-auto')} style={{
|
||||
minHeight,
|
||||
maxHeight,
|
||||
paddingRight: (value && value.trim().length > 10000) ? 140 : 130,
|
||||
}}>
|
||||
{!value ? placeholder : value.replace(/\n$/, '\n ')}
|
||||
</div>
|
||||
<textarea
|
||||
ref={ref}
|
||||
autoFocus={autoFocus}
|
||||
className={cn(className, 'absolute inset-0 resize-none overflow-hidden')}
|
||||
className={cn(className, 'absolute inset-0 resize-none overflow-auto')}
|
||||
style={{
|
||||
paddingRight: (value && value.trim().length > 10000) ? 140 : 130,
|
||||
}}
|
||||
placeholder={placeholder}
|
||||
onChange={onChange}
|
||||
onKeyDown={onKeyDown}
|
||||
|
||||
@@ -11,7 +11,7 @@ export const RFC_LOCALES = [
|
||||
{ value: 'en-US', name: 'EN' },
|
||||
{ value: 'zh-Hans', name: '简体中文' },
|
||||
]
|
||||
interface ISelectProps {
|
||||
type ISelectProps = {
|
||||
items: Array<{ value: string; name: string }>
|
||||
value?: string
|
||||
className?: string
|
||||
@@ -21,7 +21,7 @@ interface ISelectProps {
|
||||
export default function Select({
|
||||
items,
|
||||
value,
|
||||
onChange
|
||||
onChange,
|
||||
}: ISelectProps) {
|
||||
const item = items.filter(item => item.value === value)[0]
|
||||
|
||||
@@ -29,11 +29,12 @@ export default function Select({
|
||||
<div className="w-56 text-right">
|
||||
<Menu as="div" className="relative inline-block text-left">
|
||||
<div>
|
||||
<Menu.Button className="inline-flex w-full justify-center items-center
|
||||
rounded-lg px-2 py-1
|
||||
text-gray-600 text-xs font-medium
|
||||
border border-gray-200">
|
||||
<GlobeAltIcon className="w-5 h-5 mr-2 " aria-hidden="true" />
|
||||
<Menu.Button className="inline-flex w-full h-[44px]justify-center items-center
|
||||
rounded-lg px-[10px] py-[6px]
|
||||
text-gray-900 text-[13px] font-medium
|
||||
border border-gray-200
|
||||
hover:bg-gray-100">
|
||||
<GlobeAltIcon className="w-5 h-5 mr-1" aria-hidden="true" />
|
||||
{item?.name}
|
||||
</Menu.Button>
|
||||
</div>
|
||||
@@ -46,14 +47,14 @@ export default function Select({
|
||||
leaveFrom="transform opacity-100 scale-100"
|
||||
leaveTo="transform opacity-0 scale-95"
|
||||
>
|
||||
<Menu.Items className="absolute right-0 mt-2 w-28 origin-top-right divide-y divide-gray-100 rounded-md bg-white shadow-lg ring-1 ring-black ring-opacity-5 focus:outline-none">
|
||||
<Menu.Items className="absolute right-0 mt-2 w-[120px] origin-top-right divide-y divide-gray-100 rounded-md bg-white shadow-lg ring-1 ring-black ring-opacity-5 focus:outline-none">
|
||||
<div className="px-1 py-1 ">
|
||||
{items.map((item) => {
|
||||
return <Menu.Item key={item.value}>
|
||||
{({ active }) => (
|
||||
<button
|
||||
className={`${active ? 'bg-gray-100' : ''
|
||||
} group flex w-full items-center rounded-md px-2 py-2 text-sm`}
|
||||
} group flex w-full items-center rounded-lg px-3 py-2 text-sm text-gray-700`}
|
||||
onClick={(evt) => {
|
||||
evt.preventDefault()
|
||||
onChange && onChange(item.value)
|
||||
@@ -77,7 +78,7 @@ export default function Select({
|
||||
export function InputSelect({
|
||||
items,
|
||||
value,
|
||||
onChange
|
||||
onChange,
|
||||
}: ISelectProps) {
|
||||
const item = items.filter(item => item.value === value)[0]
|
||||
return (
|
||||
@@ -104,7 +105,7 @@ export function InputSelect({
|
||||
{({ active }) => (
|
||||
<button
|
||||
className={`${active ? 'bg-gray-100' : ''
|
||||
} group flex w-full items-center rounded-md px-2 py-2 text-sm`}
|
||||
} group flex w-full items-center rounded-md px-2 py-2 text-sm`}
|
||||
onClick={() => {
|
||||
onChange && onChange(item.value)
|
||||
}}
|
||||
@@ -122,4 +123,4 @@ export function InputSelect({
|
||||
</Menu>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useParams, usePathname } from 'next/navigation'
|
||||
import cn from 'classnames'
|
||||
import Recorder from 'js-audio-recorder'
|
||||
import { useRafInterval } from 'ahooks'
|
||||
import { convertToMp3 } from './utils'
|
||||
import s from './index.module.css'
|
||||
import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
|
||||
import { Loading02, XClose } from '@/app/components/base/icons/src/vender/line/general'
|
||||
@@ -19,7 +20,12 @@ const VoiceInput = ({
|
||||
onConverted,
|
||||
}: VoiceInputTypes) => {
|
||||
const { t } = useTranslation()
|
||||
const recorder = useRef(new Recorder())
|
||||
const recorder = useRef(new Recorder({
|
||||
sampleBits: 16,
|
||||
sampleRate: 16000,
|
||||
numChannels: 1,
|
||||
compiling: false,
|
||||
}))
|
||||
const canvasRef = useRef<HTMLCanvasElement | null>(null)
|
||||
const ctxRef = useRef<CanvasRenderingContext2D | null>(null)
|
||||
const drawRecordId = useRef<number | null>(null)
|
||||
@@ -75,10 +81,10 @@ const VoiceInput = ({
|
||||
const canvas = canvasRef.current!
|
||||
const ctx = ctxRef.current!
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height)
|
||||
const wavBlob = recorder.current.getWAVBlob()
|
||||
const wavFile = new File([wavBlob], 'a.wav', { type: 'audio/wav' })
|
||||
const mp3Blob = convertToMp3(recorder.current)
|
||||
const mp3File = new File([mp3Blob], 'temp.mp3', { type: 'audio/mp3' })
|
||||
const formData = new FormData()
|
||||
formData.append('file', wavFile)
|
||||
formData.append('file', mp3File)
|
||||
|
||||
let url = ''
|
||||
let isPublic = false
|
||||
|
||||
38
web/app/components/base/voice-input/utils.ts
Normal file
38
web/app/components/base/voice-input/utils.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import lamejs from 'lamejs'
|
||||
|
||||
export const convertToMp3 = (recorder: any) => {
|
||||
const wav = lamejs.WavHeader.readHeader(recorder.getWAV())
|
||||
const { channels, sampleRate } = wav
|
||||
const mp3enc = new lamejs.Mp3Encoder(channels, sampleRate, 128)
|
||||
const result = recorder.getChannelData()
|
||||
const buffer = []
|
||||
|
||||
const leftData = result.left && new Int16Array(result.left.buffer, 0, result.left.byteLength / 2)
|
||||
const rightData = result.right && new Int16Array(result.right.buffer, 0, result.right.byteLength / 2)
|
||||
const remaining = leftData.length + (rightData ? rightData.length : 0)
|
||||
|
||||
const maxSamples = 1152
|
||||
for (let i = 0; i < remaining; i += maxSamples) {
|
||||
const left = leftData.subarray(i, i + maxSamples)
|
||||
let right = null
|
||||
let mp3buf = null
|
||||
|
||||
if (channels === 2) {
|
||||
right = rightData.subarray(i, i + maxSamples)
|
||||
mp3buf = mp3enc.encodeBuffer(left, right)
|
||||
}
|
||||
else {
|
||||
mp3buf = mp3enc.encodeBuffer(left)
|
||||
}
|
||||
|
||||
if (mp3buf.length > 0)
|
||||
buffer.push(mp3buf)
|
||||
}
|
||||
|
||||
const enc = mp3enc.flush()
|
||||
|
||||
if (enc.length > 0)
|
||||
buffer.push(enc)
|
||||
|
||||
return new Blob(buffer, { type: 'audio/mp3' })
|
||||
}
|
||||
@@ -138,7 +138,6 @@ For high-quality text generation, such as articles, summaries, and translations,
|
||||
"avatar_url": "https://assets.protocol.chat/avatars/frank.jpg",
|
||||
"display_name": null,
|
||||
"conversation_id": "xgQQXg3hrtjh7AvZ",
|
||||
"last_active_at": 705103200,
|
||||
"created_at": 692233200
|
||||
},
|
||||
{
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user