Mirror of https://github.com/langgenius/dify.git (synced 2026-04-05 16:59:21 +08:00)

Feat/dataset notion import (#392)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
@@ -22,6 +22,7 @@ CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

# redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME=''
REDIS_PASSWORD=difyai123456
REDIS_DB=0
@@ -20,7 +20,7 @@ from extensions.ext_database import db
from extensions.ext_login import login_manager

# DO NOT REMOVE BELOW
-from models import model, account, dataset, web, task
+from models import model, account, dataset, web, task, source
from events import event_handlers
# DO NOT REMOVE ABOVE
@@ -187,6 +187,9 @@ class Config:
        # For temp use only
        # set default LLM provider, default is 'openai', support `azure_openai`
        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+        # notion import setting
+        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
+        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')


class CloudEditionConfig(Config):
@@ -12,10 +12,10 @@ from . import setup, version, apikey, admin
from .app import app, site, completion, model_config, statistic, conversation, message, generator

# Import auth controllers
-from .auth import login, oauth
+from .auth import login, oauth, data_source_oauth

# Import datasets controllers
-from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing
+from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

# Import workspace controllers
from .workspace import workspace, members, providers, account
api/controllers/console/auth/data_source_oauth.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import logging
from datetime import datetime
from typing import Optional

import flask_login
import requests
from flask import request, redirect, current_app, session
from flask_login import current_user, login_required
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from libs.oauth_data_source import NotionOAuth
from controllers.console import api
from ..setup import setup_required
from ..wraps import account_initialization_required


def get_oauth_providers():
    with current_app.app_context():
        notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
                                   client_secret=current_app.config.get(
                                       'NOTION_CLIENT_SECRET'),
                                   redirect_uri=current_app.config.get(
                                       'CONSOLE_URL') + '/console/api/oauth/data-source/callback/notion')

        OAUTH_PROVIDERS = {
            'notion': notion_oauth
        }
        return OAUTH_PROVIDERS


class OAuthDataSource(Resource):
    def get(self, provider: str):
        # The role of the current user in the workspace must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
            if not oauth_provider:
                return {'error': 'Invalid provider'}, 400

        auth_url = oauth_provider.get_authorization_url()
        return redirect(auth_url)


class OAuthDataSourceCallback(Resource):
    def get(self, provider: str):
        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
            if not oauth_provider:
                return {'error': 'Invalid provider'}, 400
        if 'code' in request.args:
            code = request.args.get('code')
            try:
                oauth_provider.get_access_token(code)
            except requests.exceptions.HTTPError as e:
                logging.exception(
                    f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
                return {'error': 'OAuth data source process failed'}, 400

            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
        elif 'error' in request.args:
            error = request.args.get('error')
            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source={error}')
        else:
            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=access_denied')


class OAuthDataSourceSync(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, binding_id):
        provider = str(provider)
        binding_id = str(binding_id)
        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
            if not oauth_provider:
                return {'error': 'Invalid provider'}, 400
        try:
            oauth_provider.sync_data_source(binding_id)
        except requests.exceptions.HTTPError as e:
            logging.exception(
                f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
            return {'error': 'OAuth data source process failed'}, 400

        return {'result': 'success'}, 200


api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
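Editor's note: the three routes above form the authorize → callback → sync flow for a Notion binding. A minimal client-side sketch of that flow, assuming a logged-in console session cookie and a local CONSOLE_URL (both illustrative, not part of this commit):

import requests

# Hypothetical values for illustration only.
CONSOLE_URL = 'http://localhost:5001'
session = requests.Session()
session.cookies.set('session', '<console session cookie>')

# 1. Kick off the OAuth dance; the server answers with a redirect to Notion.
r = session.get(f'{CONSOLE_URL}/console/api/oauth/data-source/notion',
                allow_redirects=False)
print(r.headers['Location'])  # Notion authorization URL

# 2. Notion redirects back to /console/api/oauth/data-source/callback/notion?code=...
#    which exchanges the code and persists a DataSourceBinding.

# 3. Later, re-sync the authorized pages for an existing binding.
binding_id = '<uuid of the DataSourceBinding>'
r = session.get(f'{CONSOLE_URL}/console/api/oauth/data-source/notion/{binding_id}/sync')
print(r.json())  # {'result': 'success'} on success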
api/controllers/console/datasets/data_source.py (new file, 303 lines)
@@ -0,0 +1,303 @@
import datetime
import json

from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task

cache = TTLCache(maxsize=None, ttl=30)

FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
PREVIEW_WORDS_LIMIT = 3000


class DataSourceApi(Resource):
    integrate_icon_fields = {
        'type': fields.String,
        'url': fields.String,
        'emoji': fields.String
    }
    integrate_page_fields = {
        'page_name': fields.String,
        'page_id': fields.String,
        'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
        'parent_id': fields.String,
        'type': fields.String
    }
    integrate_workspace_fields = {
        'workspace_name': fields.String,
        'workspace_id': fields.String,
        'workspace_icon': fields.String,
        'pages': fields.List(fields.Nested(integrate_page_fields)),
        'total': fields.Integer
    }
    integrate_fields = {
        'id': fields.String,
        'provider': fields.String,
        'created_at': TimestampField,
        'is_bound': fields.Boolean,
        'disabled': fields.Boolean,
        'link': fields.String,
        'source_info': fields.Nested(integrate_workspace_fields)
    }
    integrate_list_fields = {
        'data': fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        # get workspace data source integrates
        data_source_integrates = db.session.query(DataSourceBinding).filter(
            DataSourceBinding.tenant_id == current_user.current_tenant_id,
            DataSourceBinding.disabled == False
        ).all()

        base_url = request.url_root.rstrip('/')
        data_source_oauth_base_path = "/console/api/oauth/data-source"
        providers = ["notion"]

        integrate_data = []
        for provider in providers:
            # existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None)
            # materialize the filter so the truthiness check below actually reflects emptiness
            existing_integrates = list(filter(lambda item: item.provider == provider, data_source_integrates))
            if existing_integrates:
                for existing_integrate in existing_integrates:
                    integrate_data.append({
                        'id': existing_integrate.id,
                        'provider': provider,
                        'created_at': existing_integrate.created_at,
                        'is_bound': True,
                        'disabled': existing_integrate.disabled,
                        'source_info': existing_integrate.source_info,
                        'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
                    })
            else:
                integrate_data.append({
                    'id': None,
                    'provider': provider,
                    'created_at': None,
                    'source_info': None,
                    'is_bound': False,
                    'disabled': None,
                    'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
                })
        return {'data': integrate_data}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, binding_id, action):
        binding_id = str(binding_id)
        action = str(action)
        data_source_binding = DataSourceBinding.query.filter_by(
            id=binding_id
        ).first()
        if data_source_binding is None:
            raise NotFound('Data source binding not found.')
        # enable binding
        if action == 'enable':
            if data_source_binding.disabled:
                data_source_binding.disabled = False
                data_source_binding.updated_at = datetime.datetime.utcnow()
                db.session.add(data_source_binding)
                db.session.commit()
            else:
                raise ValueError('Data source is not disabled.')
        # disable binding
        if action == 'disable':
            if not data_source_binding.disabled:
                data_source_binding.disabled = True
                data_source_binding.updated_at = datetime.datetime.utcnow()
                db.session.add(data_source_binding)
                db.session.commit()
            else:
                raise ValueError('Data source is disabled.')
        return {'result': 'success'}, 200


class DataSourceNotionListApi(Resource):
    integrate_icon_fields = {
        'type': fields.String,
        'url': fields.String,
        'emoji': fields.String
    }
    integrate_page_fields = {
        'page_name': fields.String,
        'page_id': fields.String,
        'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
        'is_bound': fields.Boolean,
        'parent_id': fields.String,
        'type': fields.String
    }
    integrate_workspace_fields = {
        'workspace_name': fields.String,
        'workspace_id': fields.String,
        'workspace_icon': fields.String,
        'pages': fields.List(fields.Nested(integrate_page_fields))
    }
    integrate_notion_info_list_fields = {
        'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_notion_info_list_fields)
    def get(self):
        dataset_id = request.args.get('dataset_id', default=None, type=str)
        exist_page_ids = []
        # importing notion pages into an existing dataset
        if dataset_id:
            dataset = DatasetService.get_dataset(dataset_id)
            if not dataset:
                raise NotFound('Dataset not found.')
            if dataset.data_source_type != 'notion_import':
                raise ValueError('Dataset is not notion type.')
            documents = Document.query.filter_by(
                dataset_id=dataset_id,
                tenant_id=current_user.current_tenant_id,
                data_source_type='notion_import',
                enabled=True
            ).all()
            if documents:
                for document in documents:
                    data_source_info = json.loads(document.data_source_info)
                    exist_page_ids.append(data_source_info['notion_page_id'])
        # get all authorized pages
        data_source_bindings = DataSourceBinding.query.filter_by(
            tenant_id=current_user.current_tenant_id,
            provider='notion',
            disabled=False
        ).all()
        if not data_source_bindings:
            return {
                'notion_info': []
            }, 200
        pre_import_info_list = []
        for data_source_binding in data_source_bindings:
            source_info = data_source_binding.source_info
            pages = source_info['pages']
            # mark pages that are already bound to a document
            for page in pages:
                if page['page_id'] in exist_page_ids:
                    page['is_bound'] = True
                else:
                    page['is_bound'] = False
            pre_import_info = {
                'workspace_name': source_info['workspace_name'],
                'workspace_icon': source_info['workspace_icon'],
                'workspace_id': source_info['workspace_id'],
                'pages': pages,
            }
            pre_import_info_list.append(pre_import_info)
        return {
            'notion_info': pre_import_info_list
        }, 200


class DataSourceNotionApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, workspace_id, page_id, page_type):
        workspace_id = str(workspace_id)
        page_id = str(page_id)
        data_source_binding = DataSourceBinding.query.filter(
            db.and_(
                DataSourceBinding.tenant_id == current_user.current_tenant_id,
                DataSourceBinding.provider == 'notion',
                DataSourceBinding.disabled == False,
                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
            )
        ).first()
        if not data_source_binding:
            raise NotFound('Data source binding not found.')
        reader = NotionPageReader(integration_token=data_source_binding.access_token)
        if page_type == 'page':
            page_content = reader.read_page(page_id)
        elif page_type == 'database':
            page_content = reader.query_database_data(page_id)
        else:
            page_content = ""
        return {
            'content': page_content
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)
        indexing_runner = IndexingRunner()
        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
        return response, 200


class DataSourceNotionDatasetSyncApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
        for document in documents:
            document_indexing_sync_task.delay(dataset_id_str, document.id)
        return 200


class DataSourceNotionDocumentSyncApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id_str = str(dataset_id)
        document_id_str = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        document = DocumentService.get_document(dataset_id_str, document_id_str)
        if document is None:
            raise NotFound("Document not found.")
        document_indexing_sync_task.delay(dataset_id_str, document_id_str)
        return 200


api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi,
                 '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
                 '/datasets/notion-indexing-estimate')
api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
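Editor's note: for reference, a hedged sketch of the request body DataSourceNotionApi.post expects, matching the shape notion_indexing_estimate consumes later in this diff (all IDs and the console URL are placeholders):

import json
import requests

# Illustrative payload; workspace/page IDs are placeholders.
payload = {
    'notion_info_list': [
        {
            'workspace_id': '<notion workspace uuid>',
            'pages': [
                {'page_id': '<notion page uuid>', 'type': 'page'},
                {'page_id': '<notion database uuid>', 'type': 'database'},
            ],
        }
    ],
    'process_rule': {
        'mode': 'automatic',  # assumed value; see DatasetProcessRule
        'rules': {},
    },
}
resp = requests.post('http://localhost:5001/console/api/datasets/notion-indexing-estimate',
                     json=payload, cookies={'session': '<console session cookie>'})
print(json.dumps(resp.json(), indent=2))  # tokens, total_price, currency, total_segments, preview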
@@ -12,8 +12,9 @@ from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile
-from services.dataset_service import DatasetService
+from services.dataset_service import DatasetService, DocumentService

dataset_detail_fields = {
    'id': fields.String,
@@ -217,17 +218,31 @@ class DatasetIndexingEstimateApi(Resource):
    @login_required
    @account_initialization_required
    def post(self):
-        segment_rule = request.get_json()
-        file_detail = db.session.query(UploadFile).filter(
-            UploadFile.tenant_id == current_user.current_tenant_id,
-            UploadFile.id == segment_rule["file_id"]
-        ).first()
+        parser = reqparse.RequestParser()
+        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        # validate args
+        DocumentService.estimate_args_validate(args)
+        if args['info_list']['data_source_type'] == 'upload_file':
+            file_ids = args['info_list']['file_info_list']['file_ids']
+            file_details = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == current_user.current_tenant_id,
+                UploadFile.id.in_(file_ids)
+            ).all()

-        if file_detail is None:
-            raise NotFound("File not found.")
+            if file_details is None:
+                raise NotFound("File not found.")

-        indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
+        elif args['info_list']['data_source_type'] == 'notion_import':
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
+                                                                args['process_rule'])
+        else:
+            raise ValueError('Data source type is not supported.')
        return response, 200
@@ -274,8 +289,54 @@ class DatasetRelatedAppListApi(Resource):
        }, 200


+class DatasetIndexingStatusApi(Resource):
+    document_status_fields = {
+        'id': fields.String,
+        'indexing_status': fields.String,
+        'processing_started_at': TimestampField,
+        'parsing_completed_at': TimestampField,
+        'cleaning_completed_at': TimestampField,
+        'splitting_completed_at': TimestampField,
+        'completed_at': TimestampField,
+        'paused_at': TimestampField,
+        'error': fields.String,
+        'stopped_at': TimestampField,
+        'completed_segments': fields.Integer,
+        'total_segments': fields.Integer,
+    }
+
+    document_status_fields_list = {
+        'data': fields.List(fields.Nested(document_status_fields))
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id = str(dataset_id)
+        documents = db.session.query(Document).filter(
+            Document.dataset_id == dataset_id,
+            Document.tenant_id == current_user.current_tenant_id
+        ).all()
+        documents_status = []
+        for document in documents:
+            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                              DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                          DocumentSegment.status != 're_segment').count()
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            documents_status.append(marshal(document, self.document_status_fields))
+        data = {
+            'data': documents_status
+        }
+        return data
+
+
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
+api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
+api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
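Editor's note: a hedged sketch of polling the new indexing-status route until every document in a dataset finishes (base URL and cookie are illustrative):

import time
import requests

# Illustrative values.
BASE = 'http://localhost:5001/console/api'
dataset_id = '<dataset uuid>'
cookies = {'session': '<console session cookie>'}

while True:
    statuses = requests.get(f'{BASE}/datasets/{dataset_id}/indexing-status',
                            cookies=cookies).json()['data']
    for doc in statuses:
        print(doc['id'], doc['indexing_status'],
              f"{doc['completed_segments']}/{doc['total_segments']} segments")
    if all(doc['indexing_status'] in ('completed', 'error') for doc in statuses):
        break
    time.sleep(2)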
@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
import random
from datetime import datetime
+from typing import List

from flask import request
from flask_login import login_required, current_user
@@ -61,6 +62,29 @@ document_fields = {
    'hit_count': fields.Integer,
}

+document_with_segments_fields = {
+    'id': fields.String,
+    'position': fields.Integer,
+    'data_source_type': fields.String,
+    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
+    'dataset_process_rule_id': fields.String,
+    'name': fields.String,
+    'created_from': fields.String,
+    'created_by': fields.String,
+    'created_at': TimestampField,
+    'tokens': fields.Integer,
+    'indexing_status': fields.String,
+    'error': fields.String,
+    'enabled': fields.Boolean,
+    'disabled_at': TimestampField,
+    'disabled_by': fields.String,
+    'archived': fields.Boolean,
+    'display_status': fields.String,
+    'word_count': fields.Integer,
+    'hit_count': fields.Integer,
+    'completed_segments': fields.Integer,
+    'total_segments': fields.Integer
+}

class DocumentResource(Resource):
    def get_document(self, dataset_id: str, document_id: str) -> Document:
@@ -83,6 +107,23 @@ class DocumentResource(Resource):

        return document

+    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        documents = DocumentService.get_batch_documents(dataset_id, batch)
+
+        if not documents:
+            raise NotFound('Documents not found.')
+
+        return documents


class GetProcessRuleApi(Resource):
    @setup_required
@@ -132,9 +173,9 @@ class DatasetDocumentListApi(Resource):
        dataset_id = str(dataset_id)
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
-        search = request.args.get('search', default=None, type=str)
+        search = request.args.get('keyword', default=None, type=str)
        sort = request.args.get('sort', default='-created_at', type=str)
+        fetch = request.args.get('fetch', default=False, type=bool)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
@@ -173,9 +214,20 @@ class DatasetDocumentListApi(Resource):
        paginated_documents = query.paginate(
            page=page, per_page=limit, max_per_page=100, error_out=False)
        documents = paginated_documents.items
+        if fetch:
+            for document in documents:
+                completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                                  DocumentSegment.document_id == str(document.id),
+                                                                  DocumentSegment.status != 're_segment').count()
+                total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+                document.completed_segments = completed_segments
+                document.total_segments = total_segments
+            data = marshal(documents, document_with_segments_fields)
+        else:
+            data = marshal(documents, document_fields)
        response = {
-            'data': marshal(documents, document_fields),
+            'data': data,
            'has_more': len(documents) == limit,
            'limit': limit,
            'total': paginated_documents.total,
@@ -184,10 +236,15 @@ class DatasetDocumentListApi(Resource):

        return response

+    documents_and_batch_fields = {
+        'documents': fields.List(fields.Nested(document_fields)),
+        'batch': fields.String
+    }
+
    @setup_required
    @login_required
    @account_initialization_required
-    @marshal_with(document_fields)
+    @marshal_with(documents_and_batch_fields)
    def post(self, dataset_id):
        dataset_id = str(dataset_id)
@@ -221,7 +278,7 @@ class DatasetDocumentListApi(Resource):
        DocumentService.document_create_args_validate(args)

        try:
-            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
@@ -229,13 +286,17 @@ class DatasetDocumentListApi(Resource):
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()

-        return document
+        return {
+            'documents': documents,
+            'batch': batch
+        }
class DatasetInitApi(Resource):
    dataset_and_document_fields = {
        'dataset': fields.Nested(dataset_fields),
-        'document': fields.Nested(document_fields)
+        'documents': fields.List(fields.Nested(document_fields)),
+        'batch': fields.String
    }

    @setup_required
@@ -258,7 +319,7 @@ class DatasetInitApi(Resource):
        DocumentService.document_create_args_validate(args)

        try:
-            dataset, document = DocumentService.save_document_without_dataset_id(
+            dataset, documents, batch = DocumentService.save_document_without_dataset_id(
                tenant_id=current_user.current_tenant_id,
                document_data=args,
                account=current_user
@@ -272,7 +333,8 @@ class DatasetInitApi(Resource):

        response = {
            'dataset': dataset,
-            'document': document
+            'documents': documents,
+            'batch': batch
        }

        return response
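Editor's note: the init and create endpoints now return a batch identifier alongside the documents. A hedged sketch of consuming it together with the batch status route added later in this diff (the request payload shape is assumed, only partially visible in this commit; IDs and URL are placeholders):

import requests

# Illustrative values.
BASE = 'http://localhost:5001/console/api'
cookies = {'session': '<console session cookie>'}

# Create a dataset plus documents in one call; the response carries a batch id.
resp = requests.post(f'{BASE}/datasets/init', cookies=cookies, json={
    # validated by DocumentService.document_create_args_validate (shape assumed)
    'data_source': {'type': 'upload_file',
                    'info': [{'upload_file_id': '<upload file uuid>'}]},
    'process_rule': {'mode': 'automatic', 'rules': {}},
}).json()
dataset_id = resp['dataset']['id']
batch = resp['batch']

# Poll the whole batch rather than one document at a time.
status = requests.get(f'{BASE}/datasets/{dataset_id}/batch/{batch}/indexing-status',
                      cookies=cookies).json()
print(status['data'])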
@@ -317,11 +379,122 @@ class DocumentIndexingEstimateApi(DocumentResource):
            raise NotFound('File not found.')

        indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
+        response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)

        return response
+class DocumentBatchIndexingEstimateApi(DocumentResource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, batch):
+        dataset_id = str(dataset_id)
+        batch = str(batch)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        documents = self.get_batch_documents(dataset_id, batch)
+        response = {
+            "tokens": 0,
+            "total_price": 0,
+            "currency": "USD",
+            "total_segments": 0,
+            "preview": []
+        }
+        if not documents:
+            return response
+        data_process_rule = documents[0].dataset_process_rule
+        data_process_rule_dict = data_process_rule.to_dict()
+        info_list = []
+        for document in documents:
+            if document.indexing_status in ['completed', 'error']:
+                raise DocumentAlreadyFinishedError()
+            data_source_info = document.data_source_info_dict
+            # format document files info
+            if data_source_info and 'upload_file_id' in data_source_info:
+                file_id = data_source_info['upload_file_id']
+                info_list.append(file_id)
+            # format document notion info
+            elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info:
+                pages = []
+                page = {
+                    'page_id': data_source_info['notion_page_id'],
+                    'type': data_source_info['type']
+                }
+                pages.append(page)
+                notion_info = {
+                    'workspace_id': data_source_info['notion_workspace_id'],
+                    'pages': pages
+                }
+                info_list.append(notion_info)
+
+        if dataset.data_source_type == 'upload_file':
+            file_details = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == current_user.current_tenant_id,
+                UploadFile.id.in_(info_list)  # SQLAlchemy in_(); a bare Python `in` would not build a SQL filter
+            ).all()
+
+            if file_details is None:
+                raise NotFound("File not found.")
+
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+        elif dataset.data_source_type == 'notion_import':
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.notion_indexing_estimate(info_list,
+                                                                data_process_rule_dict)
+        else:
+            raise ValueError('Data source type is not supported.')
+        return response
+class DocumentBatchIndexingStatusApi(DocumentResource):
+    document_status_fields = {
+        'id': fields.String,
+        'indexing_status': fields.String,
+        'processing_started_at': TimestampField,
+        'parsing_completed_at': TimestampField,
+        'cleaning_completed_at': TimestampField,
+        'splitting_completed_at': TimestampField,
+        'completed_at': TimestampField,
+        'paused_at': TimestampField,
+        'error': fields.String,
+        'stopped_at': TimestampField,
+        'completed_segments': fields.Integer,
+        'total_segments': fields.Integer,
+    }
+
+    document_status_fields_list = {
+        'data': fields.List(fields.Nested(document_status_fields))
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, batch):
+        dataset_id = str(dataset_id)
+        batch = str(batch)
+        documents = self.get_batch_documents(dataset_id, batch)
+        documents_status = []
+        for document in documents:
+            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                              DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                          DocumentSegment.status != 're_segment').count()
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            documents_status.append(marshal(document, self.document_status_fields))
+        data = {
+            'data': documents_status
+        }
+        return data
+
+
class DocumentIndexingStatusApi(DocumentResource):
    document_status_fields = {
        'id': fields.String,
@@ -408,7 +581,7 @@ class DocumentDetailApi(DocumentResource):
            'disabled_by': document.disabled_by,
            'archived': document.archived,
            'segment_count': document.segment_count,
            'average_segment_length': document.average_segment_length,
            'hit_count': document.hit_count,
            'display_status': document.display_status
        }
@@ -428,7 +601,7 @@ class DocumentDetailApi(DocumentResource):
            'created_at': document.created_at.timestamp(),
            'tokens': document.tokens,
            'indexing_status': document.indexing_status,
-            'completed_at': int(document.completed_at.timestamp())if document.completed_at else None,
+            'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
            'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
            'indexing_latency': document.indexing_latency,
            'error': document.error,
@@ -579,6 +752,8 @@ class DocumentStatusApi(DocumentResource):
            return {'result': 'success'}, 200

        elif action == "disable":
+            if not document.completed_at or document.indexing_status != 'completed':
+                raise InvalidActionError('Document is not completed.')
            if not document.enabled:
                raise InvalidActionError('Document already disabled.')
@@ -678,6 +853,10 @@ api.add_resource(DatasetInitApi,
                 '/datasets/init')
api.add_resource(DocumentIndexingEstimateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
+api.add_resource(DocumentBatchIndexingEstimateApi,
+                 '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate')
+api.add_resource(DocumentBatchIndexingStatusApi,
+                 '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
api.add_resource(DocumentIndexingStatusApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
api.add_resource(DocumentDetailApi,
@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
        document_data = {
            'data_source': {
                'type': 'upload_file',
-                'info': upload_file.id
+                'info': [
+                    {
+                        'upload_file_id': upload_file.id
+                    }
+                ]
            }
        }

        try:
-            document = DocumentService.save_document_with_dataset_id(
+            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                document_data=document_data,
                account=dataset.created_by_account,
@@ -83,7 +87,7 @@ class DocumentListApi(DatasetApiResource):
            )
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()

+        document = documents[0]
        if doc_type and doc_metadata:
            metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
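Editor's note: the data_source 'info' field changes here from a bare upload-file id to a list of dicts, so one request can carry several files. A hedged sketch of the new shape (ids are placeholders):

# New document_data shape after this commit; ids illustrative.
document_data = {
    'data_source': {
        'type': 'upload_file',
        'info': [
            {'upload_file_id': '<upload file uuid 1>'},
            {'upload_file_id': '<upload file uuid 2>'},
        ]
    }
}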
api/core/data_source/notion.py (new file, 367 lines)
@@ -0,0 +1,367 @@
"""Notion reader."""
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import requests  # type: ignore

from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document

INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
logger = logging.getLogger(__name__)


# TODO: Notion DB reader coming soon!
class NotionPageReader(BaseReader):
    """Notion Page reader.

    Reads a set of Notion pages.

    Args:
        integration_token (str): Notion integration token.

    """

    def __init__(self, integration_token: Optional[str] = None) -> None:
        """Initialize with parameters."""
        if integration_token is None:
            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment "
                    "variable `NOTION_INTEGRATION_TOKEN`."
                )
        self.token = integration_token
        self.headers = {
            "Authorization": "Bearer " + self.token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read a block."""
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET", block_url, headers=self.headers, json=query_dict
            )
            data = res.json()
            if 'results' not in data or data["results"] is None:
                done = True
                break
            heading = ''
            for result in data["results"]:
                result_type = result["type"]
                result_obj = result[result_type]
                cur_result_text_arr = []
                if result_type == 'table':
                    result_block_id = result["id"]
                    text = self._read_table_rows(result_block_id)
                    result_lines_arr.append(text)
                else:
                    if "rich_text" in result_obj:
                        for rich_text in result_obj["rich_text"]:
                            # skip if doesn't have text object
                            if "text" in rich_text:
                                text = rich_text["text"]["content"]
                                prefix = "\t" * num_tabs
                                cur_result_text_arr.append(prefix + text)
                                if result_type in HEADING_TYPE:
                                    heading = text
                    result_block_id = result["id"]
                    has_children = result["has_children"]
                    if has_children:
                        children_text = self._read_block(
                            result_block_id, num_tabs=num_tabs + 1
                        )
                        cur_result_text_arr.append(children_text)

                    cur_result_text = "\n".join(cur_result_text_arr)
                    if result_type in HEADING_TYPE:
                        result_lines_arr.append(cur_result_text)
                    else:
                        result_lines_arr.append(f'{heading}\n{cur_result_text}')

            if data["next_cursor"] is None:
                done = True
                break
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def _read_table_rows(self, block_id: str) -> str:
        """Read table rows."""
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET", block_url, headers=self.headers, json=query_dict
            )
            data = res.json()
            # get table headers text
            table_header_cell_texts = []
            table_header_cells = data["results"][0]['table_row']['cells']
            for table_header_cell in table_header_cells:
                if table_header_cell:
                    for table_header_cell_text in table_header_cell:
                        text = table_header_cell_text["text"]["content"]
                        table_header_cell_texts.append(text)
            # get table columns text and format
            results = data["results"]
            for i in range(len(results) - 1):
                column_texts = []
                table_column_cells = data["results"][i + 1]['table_row']['cells']
                for j in range(len(table_column_cells)):
                    if table_column_cells[j]:
                        for table_column_cell_text in table_column_cells[j]:
                            column_text = table_column_cell_text["text"]["content"]
                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')

                cur_result_text = "\n".join(column_texts)
                result_lines_arr.append(cur_result_text)

            if data["next_cursor"] is None:
                done = True
                break
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
        """Read a block."""
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET", block_url, headers=self.headers, json=query_dict
            )
            data = res.json()
            # current block's heading
            heading = ''
            for result in data["results"]:
                result_type = result["type"]
                result_obj = result[result_type]
                cur_result_text_arr = []
                if result_type == 'table':
                    result_block_id = result["id"]
                    text = self._read_table_rows(result_block_id)
                    text += "\n\n"
                    result_lines_arr.append(text)
                else:
                    if "rich_text" in result_obj:
                        for rich_text in result_obj["rich_text"]:
                            # skip if doesn't have text object
                            if "text" in rich_text:
                                text = rich_text["text"]["content"]
                                cur_result_text_arr.append(text)
                                if result_type in HEADING_TYPE:
                                    heading = text

                    result_block_id = result["id"]
                    has_children = result["has_children"]
                    if has_children:
                        children_text = self._read_block(
                            result_block_id, num_tabs=num_tabs + 1
                        )
                        cur_result_text_arr.append(children_text)

                    cur_result_text = "\n".join(cur_result_text_arr)
                    cur_result_text += "\n\n"
                    if result_type in HEADING_TYPE:
                        result_lines_arr.append(cur_result_text)
                    else:
                        result_lines_arr.append(f'{heading}\n{cur_result_text}')

            if data["next_cursor"] is None:
                done = True
                break
            else:
                cur_block_id = data["next_cursor"]
        return result_lines_arr

    def read_page(self, page_id: str) -> str:
        """Read a page."""
        return self._read_block(page_id)

    def read_page_as_documents(self, page_id: str) -> List[str]:
        """Read a page as documents."""
        return self._read_parent_blocks(page_id)

    def query_database_data(
            self, database_id: str, query_dict: Dict[str, Any] = {}
    ) -> str:
        """Get all the pages from a Notion database."""
        res = requests.post(
            DATABASE_URL_TMPL.format(database_id=database_id),
            headers=self.headers,
            json=query_dict,
        )
        data = res.json()
        database_content_list = []
        if 'results' not in data or data["results"] is None:
            return ""
        for result in data["results"]:
            properties = result['properties']
            row_data = {}  # named to avoid rebinding the response dict above
            for property_name, property_value in properties.items():
                property_type = property_value['type']
                if property_type == 'multi_select':
                    value = []
                    multi_select_list = property_value[property_type]
                    for multi_select in multi_select_list:
                        value.append(multi_select['name'])
                elif property_type == 'rich_text' or property_type == 'title':
                    if len(property_value[property_type]) > 0:
                        value = property_value[property_type][0]['plain_text']
                    else:
                        value = ''
                elif property_type == 'select' or property_type == 'status':
                    if property_value[property_type]:
                        value = property_value[property_type]['name']
                    else:
                        value = ''
                else:
                    value = property_value[property_type]
                row_data[property_name] = value
            database_content_list.append(json.dumps(row_data))

        return "\n\n".join(database_content_list)

    def query_database(
            self, database_id: str, query_dict: Dict[str, Any] = {}
    ) -> List[str]:
        """Get all the pages from a Notion database."""
        res = requests.post(
            DATABASE_URL_TMPL.format(database_id=database_id),
            headers=self.headers,
            json=query_dict,
        )
        data = res.json()
        page_ids = []
        for result in data["results"]:
            page_id = result["id"]
            page_ids.append(page_id)

        return page_ids

    def search(self, query: str) -> List[str]:
        """Search Notion page given a text query."""
        done = False
        next_cursor: Optional[str] = None
        page_ids = []
        while not done:
            query_dict = {
                "query": query,
            }
            if next_cursor is not None:
                query_dict["start_cursor"] = next_cursor
            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
            data = res.json()
            for result in data["results"]:
                page_id = result["id"]
                page_ids.append(page_id)

            if data["next_cursor"] is None:
                done = True
                break
            else:
                next_cursor = data["next_cursor"]
        return page_ids

    def load_data(
            self, page_ids: List[str] = [], database_id: Optional[str] = None
    ) -> List[Document]:
        """Load data from the input directory.

        Args:
            page_ids (List[str]): List of page ids to load.

        Returns:
            List[Document]: List of documents.

        """
        if not page_ids and not database_id:
            raise ValueError("Must specify either `page_ids` or `database_id`.")
        docs = []
        if database_id is not None:
            # get all the pages in the database
            page_ids = self.query_database(database_id)
            for page_id in page_ids:
                page_text = self.read_page(page_id)
                docs.append(Document(page_text))
        else:
            for page_id in page_ids:
                page_text = self.read_page(page_id)
                docs.append(Document(page_text))

        return docs

    def load_data_as_documents(
            self, page_ids: List[str] = [], database_id: Optional[str] = None
    ) -> List[Document]:
        if not page_ids and not database_id:
            raise ValueError("Must specify either `page_ids` or `database_id`.")
        docs = []
        if database_id is not None:
            # get all the pages in the database
            page_text = self.query_database_data(database_id)
            docs.append(Document(page_text))
        else:
            for page_id in page_ids:
                page_text_list = self.read_page_as_documents(page_id)
                for page_text in page_text_list:
                    docs.append(Document(page_text))

        return docs

    def get_page_last_edited_time(self, page_id: str) -> str:
        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
        query_dict: Dict[str, Any] = {}

        res = requests.request(
            "GET", retrieve_page_url, headers=self.headers, json=query_dict
        )
        data = res.json()
        return data["last_edited_time"]

    def get_database_last_edited_time(self, database_id: str) -> str:
        retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
        query_dict: Dict[str, Any] = {}

        res = requests.request(
            "GET", retrieve_page_url, headers=self.headers, json=query_dict
        )
        data = res.json()
        return data["last_edited_time"]


if __name__ == "__main__":
    reader = NotionPageReader()
    logger.info(reader.search("What I"))
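Editor's note: a hedged sketch of using the reader directly, assuming a valid Notion integration token (token and IDs are placeholders; the methods are the ones defined in the file above):

from core.data_source.notion import NotionPageReader

# Token and IDs are placeholders for illustration.
reader = NotionPageReader(integration_token='secret_xxx')

# Flatten one page (and its child blocks) to plain text.
print(reader.read_page('<notion page uuid>'))

# Rows of a Notion database come back as JSON strings joined by blank lines.
print(reader.query_database_data('<notion database uuid>'))

# One Document per top-level block, as consumed by notion_indexing_estimate below.
docs = reader.load_data_as_documents(page_ids=['<notion page uuid>'])
print(len(docs))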
@@ -5,6 +5,8 @@ import tempfile
import time
from pathlib import Path
from typing import Optional, List

+from flask_login import current_user
from langchain.text_splitter import RecursiveCharacterTextSplitter

from llama_index import SimpleDirectoryReader
@@ -13,6 +15,8 @@ from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser

+from core.data_source.notion import NotionPageReader
from core.index.readers.xlsx_parser import XLSXParser
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
@@ -27,6 +31,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
+from models.source import DataSourceBinding


class IndexingRunner:
@@ -35,42 +40,43 @@ class IndexingRunner:
        self.storage = storage
        self.embedding_model_name = embedding_model_name

-    def run(self, document: Document):
+    def run(self, documents: List[Document]):
        """Run the indexing process."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
+        for document in documents:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=document.dataset_id
+            ).first()

-        if not dataset:
-            raise ValueError("no dataset found")
+            if not dataset:
+                raise ValueError("no dataset found")

-        # load file
-        text_docs = self._load_data(document)
+            # load file
+            text_docs = self._load_data(document)

-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                first()

-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)

-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
+            # split to nodes
+            nodes = self._step_split(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                dataset=dataset,
+                document=document,
+                processing_rule=processing_rule
+            )

-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # build index
+            self._build_index(
+                dataset=dataset,
+                document=document,
+                nodes=nodes
+            )

    def run_in_splitting_status(self, document: Document):
        """Run the indexing process when the index_status is splitting."""
@@ -164,38 +170,98 @@ class IndexingRunner:
            nodes=nodes
        )

-    def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
+    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
        """
        Estimate the indexing for the document.
        """
-        # load data from file
-        text_docs = self._load_data_from_file(file_detail)
-
-        processing_rule = DatasetProcessRule(
-            mode=tmp_processing_rule["mode"],
-            rules=json.dumps(tmp_processing_rule["rules"])
-        )
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-
-        # split to nodes
-        nodes = self._split_to_nodes(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            processing_rule=processing_rule
-        )
-
        tokens = 0
        preview_texts = []
-        for node in nodes:
-            if len(preview_texts) < 5:
-                preview_texts.append(node.get_text())
-
-            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+        total_segments = 0
+        for file_detail in file_details:
+            # load data from file
+            text_docs = self._load_data_from_file(file_detail)
+
+            processing_rule = DatasetProcessRule(
+                mode=tmp_processing_rule["mode"],
+                rules=json.dumps(tmp_processing_rule["rules"])
+            )
+
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)
+
+            # split to nodes
+            nodes = self._split_to_nodes(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                processing_rule=processing_rule
+            )
+            total_segments += len(nodes)
+            for node in nodes:
+                if len(preview_texts) < 5:
+                    preview_texts.append(node.get_text())
+
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

        return {
-            "total_segments": len(nodes),
+            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
            "preview": preview_texts
        }
|
||||
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
reader = NotionPageReader(integration_token=data_source_binding.access_token)
|
||||
for page in notion_info['pages']:
|
||||
if page['type'] == 'page':
|
||||
page_ids = [page['page_id']]
|
||||
documents = reader.load_data_as_documents(page_ids=page_ids)
|
||||
elif page['type'] == 'database':
|
||||
documents = reader.load_data_as_documents(database_id=page['page_id'])
|
||||
else:
|
||||
documents = []
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"],
|
||||
rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
|
||||
# get node parser for splitting
|
||||
node_parser = self._get_node_parser(processing_rule)
|
||||
|
||||
# split to nodes
|
||||
nodes = self._split_to_nodes(
|
||||
text_docs=documents,
|
||||
node_parser=node_parser,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
total_segments += len(nodes)
|
||||
for node in nodes:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(node.get_text())
|
||||
|
||||
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
|
||||
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
|
||||
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
||||
@@ -204,25 +270,50 @@ class IndexingRunner:

    def _load_data(self, document: Document) -> List[Document]:
        # load file
-        if document.data_source_type != "upload_file":
+        if document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = document.data_source_info_dict
-        if not data_source_info or 'upload_file_id' not in data_source_info:
-            raise ValueError("no upload file found")
-
-        file_detail = db.session.query(UploadFile). \
-            filter(UploadFile.id == data_source_info['upload_file_id']). \
-            one_or_none()
-
-        text_docs = self._load_data_from_file(file_detail)
+        text_docs = []
+        if document.data_source_type == 'upload_file':
+            if not data_source_info or 'upload_file_id' not in data_source_info:
+                raise ValueError("no upload file found")
+
+            file_detail = db.session.query(UploadFile). \
+                filter(UploadFile.id == data_source_info['upload_file_id']). \
+                one_or_none()
+
+            text_docs = self._load_data_from_file(file_detail)
+        elif document.data_source_type == 'notion_import':
+            if not data_source_info or 'notion_page_id' not in data_source_info \
+                    or 'notion_workspace_id' not in data_source_info:
+                raise ValueError("no notion page found")
+            workspace_id = data_source_info['notion_workspace_id']
+            page_id = data_source_info['notion_page_id']
+            page_type = data_source_info['type']
+            data_source_binding = DataSourceBinding.query.filter(
+                db.and_(
+                    DataSourceBinding.tenant_id == document.tenant_id,
+                    DataSourceBinding.provider == 'notion',
+                    DataSourceBinding.disabled == False,
+                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                )
+            ).first()
+            if not data_source_binding:
+                raise ValueError('Data source binding not found.')
+            if page_type == 'page':
+                # add page last_edited_time to data_source_info
+                self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
+                text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
+            elif page_type == 'database':
+                # add database last_edited_time to data_source_info
+                self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
+                text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
        # update document status to splitting
        self._update_document_index_status(
            document_id=document.id,
            after_indexing_status="splitting",
            extra_update_params={
                Document.file_id: file_detail.id,
                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
                Document.parsing_completed_at: datetime.datetime.utcnow()
            }
@@ -259,6 +350,41 @@ class IndexingRunner:
|
||||
|
||||
return text_docs
|
||||
|
||||
def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
|
||||
page_ids = [page_id]
|
||||
reader = NotionPageReader(integration_token=access_token)
|
||||
text_docs = reader.load_data_as_documents(page_ids=page_ids)
|
||||
return text_docs
|
||||
|
||||
def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
|
||||
reader = NotionPageReader(integration_token=access_token)
|
||||
text_docs = reader.load_data_as_documents(database_id=database_id)
|
||||
return text_docs
|
||||
|
||||
def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
|
||||
reader = NotionPageReader(integration_token=access_token)
|
||||
last_edited_time = reader.get_page_last_edited_time(page_id)
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info['last_edited_time'] = last_edited_time
|
||||
update_params = {
|
||||
Document.data_source_info: json.dumps(data_source_info)
|
||||
}
|
||||
|
||||
Document.query.filter_by(id=document.id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
|
||||
reader = NotionPageReader(integration_token=access_token)
|
||||
last_edited_time = reader.get_database_last_edited_time(page_id)
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info['last_edited_time'] = last_edited_time
|
||||
update_params = {
|
||||
Document.data_source_info: json.dumps(data_source_info)
|
||||
}
|
||||
|
||||
Document.query.filter_by(id=document.id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
|
||||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
@@ -308,7 +434,7 @@ class IndexingRunner:
|
||||
embedding_model_name=self.embedding_model_name,
|
||||
document_id=document.id
|
||||
)
|
||||
|
||||
# add document segments
|
||||
doc_store.add_documents(nodes)
|
||||
|
||||
# update document status to indexing
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import json
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceBinding
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -134,3 +139,5 @@ class GoogleOAuth(OAuth):
|
||||
name=None,
|
||||
email=raw_info['email']
|
||||
)
|
||||
|
||||
|
||||
|
||||
256
api/libs/oauth_data_source.py
Normal file
256
api/libs/oauth_data_source.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceBinding
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NotionOAuth(OAuthDataSource):
|
||||
_AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
|
||||
_TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
|
||||
_NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
|
||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||
|
||||
def get_authorization_url(self):
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'owner': 'user'
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
'code': code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.redirect_uri
|
||||
}
|
||||
headers = {'Accept': 'application/json'}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get('access_token')
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in Notion OAuth: {response_json}")
|
||||
workspace_name = response_json.get('workspace_name')
|
||||
workspace_icon = response_json.get('workspace_icon')
|
||||
workspace_id = response_json.get('workspace_id')
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
'workspace_name': workspace_name,
|
||||
'workspace_icon': workspace_icon,
|
||||
'workspace_id': workspace_id,
|
||||
'pages': pages,
|
||||
'total': len(pages)
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.access_token == access_token
|
||||
)
|
||||
).first()
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.disabled = False
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
provider='notion'
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str):
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.id == binding_id,
|
||||
DataSourceBinding.disabled == False
|
||||
)
|
||||
).first()
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = data_source_binding.source_info
|
||||
new_source_info = {
|
||||
'workspace_name': source_info['workspace_name'],
|
||||
'workspace_icon': source_info['workspace_icon'],
|
||||
'workspace_id': source_info['workspace_id'],
|
||||
'pages': pages,
|
||||
'total': len(pages)
|
||||
}
|
||||
data_source_binding.source_info = new_source_info
|
||||
data_source_binding.disabled = False
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError('Data source binding not found')
|
||||
|
||||
def get_authorized_pages(self, access_token: str):
|
||||
pages = []
|
||||
page_results = self.notion_page_search(access_token)
|
||||
database_results = self.notion_database_search(access_token)
|
||||
# get page detail
|
||||
for page_result in page_results:
|
||||
page_id = page_result['id']
|
||||
if 'Name' in page_result['properties']:
|
||||
if len(page_result['properties']['Name']['title']) > 0:
|
||||
page_name = page_result['properties']['Name']['title'][0]['plain_text']
|
||||
else:
|
||||
page_name = 'Untitled'
|
||||
elif 'title' in page_result['properties']:
|
||||
if len(page_result['properties']['title']['title']) > 0:
|
||||
page_name = page_result['properties']['title']['title'][0]['plain_text']
|
||||
else:
|
||||
page_name = 'Untitled'
|
||||
elif 'Title' in page_result['properties']:
|
||||
if len(page_result['properties']['Title']['title']) > 0:
|
||||
page_name = page_result['properties']['Title']['title'][0]['plain_text']
|
||||
else:
|
||||
page_name = 'Untitled'
|
||||
else:
|
||||
page_name = 'Untitled'
|
||||
page_icon = page_result['icon']
|
||||
if page_icon:
|
||||
icon_type = page_icon['type']
|
||||
if icon_type == 'external' or icon_type == 'file':
|
||||
url = page_icon[icon_type]['url']
|
||||
icon = {
|
||||
'type': 'url',
|
||||
'url': url if url.startswith('http') else f'https://www.notion.so{url}'
|
||||
}
|
||||
else:
|
||||
icon = {
|
||||
'type': 'emoji',
|
||||
'emoji': page_icon[icon_type]
|
||||
}
|
||||
else:
|
||||
icon = None
|
||||
parent = page_result['parent']
|
||||
parent_type = parent['type']
|
||||
if parent_type == 'block_id':
|
||||
parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
elif parent_type == 'workspace':
|
||||
parent_id = 'root'
|
||||
else:
|
||||
parent_id = parent[parent_type]
|
||||
page = {
|
||||
'page_id': page_id,
|
||||
'page_name': page_name,
|
||||
'page_icon': icon,
|
||||
'parent_id': parent_id,
|
||||
'type': 'page'
|
||||
}
|
||||
pages.append(page)
|
||||
# get database detail
|
||||
for database_result in database_results:
|
||||
page_id = database_result['id']
|
||||
if len(database_result['title']) > 0:
|
||||
page_name = database_result['title'][0]['plain_text']
|
||||
else:
|
||||
page_name = 'Untitled'
|
||||
page_icon = database_result['icon']
|
||||
if page_icon:
|
||||
icon_type = page_icon['type']
|
||||
if icon_type == 'external' or icon_type == 'file':
|
||||
url = page_icon[icon_type]['url']
|
||||
icon = {
|
||||
'type': 'url',
|
||||
'url': url if url.startswith('http') else f'https://www.notion.so{url}'
|
||||
}
|
||||
else:
|
||||
icon = {
|
||||
'type': icon_type,
|
||||
icon_type: page_icon[icon_type]
|
||||
}
|
||||
else:
|
||||
icon = None
|
||||
parent = database_result['parent']
|
||||
parent_type = parent['type']
|
||||
if parent_type == 'block_id':
|
||||
parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
elif parent_type == 'workspace':
|
||||
parent_id = 'root'
|
||||
else:
|
||||
parent_id = parent[parent_type]
|
||||
page = {
|
||||
'page_id': page_id,
|
||||
'page_name': page_name,
|
||||
'page_icon': icon,
|
||||
'parent_id': parent_id,
|
||||
'type': 'database'
|
||||
}
|
||||
pages.append(page)
|
||||
return pages
|
||||
|
||||
def notion_page_search(self, access_token: str):
|
||||
data = {
|
||||
'filter': {
|
||||
"value": "page",
|
||||
"property": "object"
|
||||
}
|
||||
}
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
'Notion-Version': '2022-06-28',
|
||||
}
|
||||
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
results = response_json['results']
|
||||
return results
|
||||
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str):
|
||||
headers = {
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
'Notion-Version': '2022-06-28',
|
||||
}
|
||||
response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers)
|
||||
response_json = response.json()
|
||||
parent = response_json['parent']
|
||||
parent_type = parent['type']
|
||||
if parent_type == 'block_id':
|
||||
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
return parent[parent_type]
|
||||
|
||||
def notion_database_search(self, access_token: str):
|
||||
data = {
|
||||
'filter': {
|
||||
"value": "database",
|
||||
"property": "object"
|
||||
}
|
||||
}
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
'Notion-Version': '2022-06-28',
|
||||
}
|
||||
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
results = response_json['results']
|
||||
return results
|
||||
@@ -0,0 +1,46 @@
|
||||
"""e08af0a69ccefbb59fa80c778efee300bb780980
|
||||
|
||||
Revision ID: e32f6ccb87c6
|
||||
Revises: a45f4dfde53b
|
||||
Create Date: 2023-06-06 19:58:33.103819
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e32f6ccb87c6'
|
||||
down_revision = '614f77cecc48'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('data_source_bindings',
|
||||
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('access_token', sa.String(length=255), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('source_info', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='source_binding_pkey')
|
||||
)
|
||||
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_index('source_info_idx', postgresql_using='gin')
|
||||
batch_op.drop_index('source_binding_tenant_id_idx')
|
||||
|
||||
op.drop_table('data_source_bindings')
|
||||
# ### end Alembic commands ###
|
||||
@@ -190,7 +190,7 @@ class Document(db.Model):
|
||||
doc_type = db.Column(db.String(40), nullable=True)
|
||||
doc_metadata = db.Column(db.JSON, nullable=True)
|
||||
|
||||
DATA_SOURCES = ['upload_file']
|
||||
DATA_SOURCES = ['upload_file', 'notion_import']
|
||||
|
||||
@property
|
||||
def display_status(self):
|
||||
@@ -242,6 +242,8 @@ class Document(db.Model):
|
||||
'created_at': file_detail.created_at.timestamp()
|
||||
}
|
||||
}
|
||||
elif self.data_source_type == 'notion_import':
|
||||
return json.loads(self.data_source_info)
|
||||
return {}
|
||||
|
||||
@property
|
||||
|
||||
21
api/models/source.py
Normal file
21
api/models/source.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from extensions.ext_database import db
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
class DataSourceBinding(db.Model):
|
||||
__tablename__ = 'data_source_bindings'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
|
||||
db.Index('source_binding_tenant_id_idx', 'tenant_id'),
|
||||
db.Index('source_info_idx', "source_info", postgresql_using='gin')
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
access_token = db.Column(db.String(255), nullable=False)
|
||||
provider = db.Column(db.String(255), nullable=False)
|
||||
source_info = db.Column(JSONB, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import datetime
|
||||
import time
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -14,10 +14,12 @@ from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceBinding
|
||||
from services.errors.account import NoPermissionError
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
@@ -286,6 +288,24 @@ class DocumentService:
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.batch == batch,
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.tenant_id == current_user.current_tenant_id
|
||||
).all()
|
||||
|
||||
return documents
|
||||
@staticmethod
|
||||
def get_document_file_detail(file_id: str):
|
||||
file_detail = db.session.query(UploadFile). \
|
||||
filter(UploadFile.id == file_id). \
|
||||
@@ -344,9 +364,9 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_documents_position(dataset_id):
|
||||
documents = Document.query.filter_by(dataset_id=dataset_id).all()
|
||||
if documents:
|
||||
return len(documents) + 1
|
||||
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||
if document:
|
||||
return document.position + 1
|
||||
else:
|
||||
return 1
|
||||
|
||||
@@ -363,9 +383,11 @@ class DocumentService:
|
||||
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
IndexBuilder.get_default_service_context(dataset.tenant_id)
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
if 'original_document_id' in document_data and document_data["original_document_id"]:
|
||||
document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
|
||||
documents.append(document)
|
||||
else:
|
||||
# save process rule
|
||||
if not dataset_process_rule:
|
||||
@@ -386,46 +408,114 @@ class DocumentService:
|
||||
)
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
|
||||
file_name = ''
|
||||
data_source_info = {}
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
file_id = document_data["data_source"]["info"]
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
|
||||
# save document
|
||||
position = DocumentService.get_documents_position(dataset.id)
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
data_source_info=json.dumps(data_source_info),
|
||||
dataset_process_rule_id=dataset_process_rule.id,
|
||||
batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
|
||||
name=file_name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
# created_api_request_id = db.Column(UUID, nullable=True)
|
||||
)
|
||||
document_ids = []
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
for file_id in upload_file_list:
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
db.session.add(document)
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
exist_page_ids = []
|
||||
exist_document = dict()
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type='notion_import',
|
||||
enabled=True
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info['notion_page_id'])
|
||||
exist_document[data_source_info['notion_page_id']] = document.id
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
for page in notion_info['pages']:
|
||||
if page['page_id'] not in exist_page_ids:
|
||||
data_source_info = {
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_page_id": page['page_id'],
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
# if page['type'] == 'database':
|
||||
# document.splitting_completed_at = datetime.datetime.utcnow()
|
||||
# document.cleaning_completed_at = datetime.datetime.utcnow()
|
||||
# document.parsing_completed_at = datetime.datetime.utcnow()
|
||||
# document.completed_at = datetime.datetime.utcnow()
|
||||
# document.indexing_status = 'completed'
|
||||
# document.word_count = 0
|
||||
# document.tokens = 0
|
||||
# document.indexing_latency = 0
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
# if page['type'] != 'database':
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
else:
|
||||
exist_document.pop(page['page_id'])
|
||||
# delete not selected documents
|
||||
if len(exist_document) > 0:
|
||||
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||
db.session.commit()
|
||||
|
||||
# trigger async task
|
||||
document_indexing_task.delay(document.dataset_id, document.id)
|
||||
document_indexing_task.delay(dataset.id, document_ids)
|
||||
|
||||
return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
|
||||
created_from: str, position: int, account: Account, name: str, batch: str):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type=data_source_type,
|
||||
data_source_info=json.dumps(data_source_info),
|
||||
dataset_process_rule_id=process_rule_id,
|
||||
batch=batch,
|
||||
name=name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
)
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
@@ -460,20 +550,42 @@ class DocumentService:
|
||||
file_name = ''
|
||||
data_source_info = {}
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
file_id = document_data["data_source"]["info"]
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
for file_id in upload_file_list:
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
for page in notion_info['pages']:
|
||||
data_source_info = {
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_page_id": page['page_id'],
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
document.data_source_type = document_data["data_source"]["type"]
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.name = file_name
|
||||
@@ -513,15 +625,15 @@ class DocumentService:
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
|
||||
document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
|
||||
|
||||
cut_length = 18
|
||||
cut_name = document.name[:cut_length]
|
||||
dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
|
||||
dataset.description = 'useful for when you want to answer queries about the ' + document.name
|
||||
cut_name = documents[0].name[:cut_length]
|
||||
dataset.name = cut_name + '...'
|
||||
dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
|
||||
db.session.commit()
|
||||
|
||||
return dataset, document
|
||||
return dataset, documents, batch
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
@@ -552,9 +664,15 @@ class DocumentService:
|
||||
if args['data_source']['type'] not in Document.DATA_SOURCES:
|
||||
raise ValueError("Data source type is invalid")
|
||||
|
||||
if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:
|
||||
raise ValueError("Data source info is required")
|
||||
|
||||
if args['data_source']['type'] == 'upload_file':
|
||||
if 'info' not in args['data_source'] or not args['data_source']['info']:
|
||||
raise ValueError("Data source info is required")
|
||||
if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
|
||||
raise ValueError("File source info is required")
|
||||
if args['data_source']['type'] == 'notion_import':
|
||||
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
|
||||
raise ValueError("Notion source info is required")
|
||||
|
||||
@classmethod
|
||||
def process_rule_args_validate(cls, args: dict):
|
||||
@@ -624,3 +742,78 @@ class DocumentService:
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
@classmethod
|
||||
def estimate_args_validate(cls, args: dict):
|
||||
if 'info_list' not in args or not args['info_list']:
|
||||
raise ValueError("Data source info is required")
|
||||
|
||||
if not isinstance(args['info_list'], dict):
|
||||
raise ValueError("Data info is invalid")
|
||||
|
||||
if 'process_rule' not in args or not args['process_rule']:
|
||||
raise ValueError("Process rule is required")
|
||||
|
||||
if not isinstance(args['process_rule'], dict):
|
||||
raise ValueError("Process rule is invalid")
|
||||
|
||||
if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
|
||||
raise ValueError("Process rule mode is required")
|
||||
|
||||
if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
|
||||
raise ValueError("Process rule mode is invalid")
|
||||
|
||||
if args['process_rule']['mode'] == 'automatic':
|
||||
args['process_rule']['rules'] = {}
|
||||
else:
|
||||
if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
|
||||
raise ValueError("Process rule rules is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules'], dict):
|
||||
raise ValueError("Process rule rules is invalid")
|
||||
|
||||
if 'pre_processing_rules' not in args['process_rule']['rules'] \
|
||||
or args['process_rule']['rules']['pre_processing_rules'] is None:
|
||||
raise ValueError("Process rule pre_processing_rules is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
|
||||
raise ValueError("Process rule pre_processing_rules is invalid")
|
||||
|
||||
unique_pre_processing_rule_dicts = {}
|
||||
for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
|
||||
if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
|
||||
raise ValueError("Process rule pre_processing_rules id is required")
|
||||
|
||||
if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
|
||||
raise ValueError("Process rule pre_processing_rules id is invalid")
|
||||
|
||||
if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
|
||||
raise ValueError("Process rule pre_processing_rules enabled is required")
|
||||
|
||||
if not isinstance(pre_processing_rule['enabled'], bool):
|
||||
raise ValueError("Process rule pre_processing_rules enabled is invalid")
|
||||
|
||||
unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
|
||||
|
||||
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
|
||||
|
||||
if 'segmentation' not in args['process_rule']['rules'] \
|
||||
or args['process_rule']['rules']['segmentation'] is None:
|
||||
raise ValueError("Process rule segmentation is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
|
||||
raise ValueError("Process rule segmentation is invalid")
|
||||
|
||||
if 'separator' not in args['process_rule']['rules']['segmentation'] \
|
||||
or not args['process_rule']['rules']['segmentation']['separator']:
|
||||
raise ValueError("Process rule segmentation separator is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
|
||||
raise ValueError("Process rule segmentation separator is invalid")
|
||||
|
||||
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
|
||||
or not args['process_rule']['rules']['segmentation']['max_tokens']:
|
||||
raise ValueError("Process rule segmentation max_tokens is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
58
api/tasks/clean_notion_document_task.py
Normal file
58
api/tasks/clean_notion_document_task.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.index.keyword_table_index import KeywordTableIndex
|
||||
from core.index.vector_index import VectorIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
|
||||
|
||||
@shared_task
|
||||
def clean_notion_document_task(document_ids: List[str], dataset_id: str):
|
||||
"""
|
||||
Clean document when document deleted.
|
||||
:param document_ids: document ids
|
||||
:param dataset_id: dataset id
|
||||
|
||||
Usage: clean_notion_document_task.delay(document_ids, dataset_id)
|
||||
"""
|
||||
logging.info(click.style('Start clean document when import form notion document deleted: {}'.format(dataset_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception('Document has no dataset')
|
||||
|
||||
vector_index = VectorIndex(dataset=dataset)
|
||||
keyword_table_index = KeywordTableIndex(dataset=dataset)
|
||||
for document_id in document_ids:
|
||||
document = db.session.query(Document).filter(
|
||||
Document.id == document_id
|
||||
).first()
|
||||
db.session.delete(document)
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
vector_index.del_nodes(index_node_ids)
|
||||
|
||||
# delete from keyword index
|
||||
if index_node_ids:
|
||||
keyword_table_index.del_nodes(index_node_ids)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Clean document when import form notion document deleted end :: {} latency: {}'.format(
|
||||
dataset_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Cleaned document when import form notion document deleted failed")
|
||||
109
api/tasks/document_indexing_sync_task.py
Normal file
109
api/tasks/document_indexing_sync_task.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.data_source.notion import NotionPageReader
|
||||
from core.index.keyword_table_index import KeywordTableIndex
|
||||
from core.index.vector_index import VectorIndex
|
||||
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Dataset, DocumentSegment
|
||||
from models.source import DataSourceBinding
|
||||
|
||||
|
||||
@shared_task
|
||||
def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async update document
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
|
||||
Usage: document_indexing_sync_task.delay(dataset_id, document_id)
|
||||
"""
|
||||
logging.info(click.style('Start sync document: {}'.format(document_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id
|
||||
).first()
|
||||
|
||||
if not document:
|
||||
raise NotFound('Document not found')
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type == 'notion_import':
|
||||
if not data_source_info or 'notion_page_id' not in data_source_info \
|
||||
or 'notion_workspace_id' not in data_source_info:
|
||||
raise ValueError("no notion page found")
|
||||
workspace_id = data_source_info['notion_workspace_id']
|
||||
page_id = data_source_info['notion_page_id']
|
||||
page_edited_time = data_source_info['last_edited_time']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == document.tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
reader = NotionPageReader(integration_token=data_source_binding.access_token)
|
||||
last_edited_time = reader.get_page_last_edited_time(page_id)
|
||||
# check the page is updated
|
||||
if last_edited_time != page_edited_time:
|
||||
document.indexing_status = 'parsing'
|
||||
document.processing_started_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception('Dataset not found')
|
||||
|
||||
vector_index = VectorIndex(dataset=dataset)
|
||||
keyword_table_index = KeywordTableIndex(dataset=dataset)
|
||||
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
vector_index.del_nodes(index_node_ids)
|
||||
|
||||
# delete from keyword index
|
||||
if index_node_ids:
|
||||
keyword_table_index.del_nodes(index_node_ids)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Cleaned document when document update data source or process rule failed")
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
|
||||
except DocumentIsPausedException:
|
||||
logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
|
||||
except ProviderTokenNotInitError as e:
|
||||
document.indexing_status = 'error'
|
||||
document.error = str(e.description)
|
||||
document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
logging.exception("consume update document failed")
|
||||
document.indexing_status = 'error'
|
||||
document.error = str(e)
|
||||
document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
@@ -13,32 +13,36 @@ from models.dataset import Document
|
||||
|
||||
|
||||
@shared_task
|
||||
def document_indexing_task(dataset_id: str, document_id: str):
|
||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
Async process document
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
:param document_ids:
|
||||
|
||||
Usage: document_indexing_task.delay(dataset_id, document_id)
|
||||
"""
|
||||
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
documents = []
|
||||
for document_id in document_ids:
|
||||
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id
|
||||
).first()
|
||||
document = db.session.query(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id
|
||||
).first()
|
||||
|
||||
if not document:
|
||||
raise NotFound('Document not found')
|
||||
if not document:
|
||||
raise NotFound('Document not found')
|
||||
|
||||
document.indexing_status = 'parsing'
|
||||
document.processing_started_at = datetime.datetime.utcnow()
|
||||
document.indexing_status = 'parsing'
|
||||
document.processing_started_at = datetime.datetime.utcnow()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(document)
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
|
||||
except DocumentIsPausedException:
|
||||
|
||||
@@ -67,7 +67,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
||||
logging.exception("Cleaned document when document update data source or process rule failed")
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(document)
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
|
||||
except DocumentIsPausedException:
|
||||
|
||||
@@ -34,7 +34,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
if document.indexing_status in ["waiting", "parsing", "cleaning"]:
|
||||
indexing_runner.run(document)
|
||||
indexing_runner.run([document])
|
||||
elif document.indexing_status == "splitting":
|
||||
indexing_runner.run_in_splitting_status(document)
|
||||
elif document.indexing_status == "indexing":
|
||||
|
||||
Reference in New Issue
Block a user