mirror of
https://github.com/langgenius/dify.git
synced 2026-04-05 15:29:21 +08:00
feat: upgrade langchain (#430)
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
175
api/core/index/vector_index/base.py
Normal file
175
api/core/index/vector_index/base.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import List, Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore
|
||||
from weaviate import UnexpectedStatusCodeException
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
class BaseVectorIndex(BaseIndex):
|
||||
|
||||
def __init__(self, dataset: Dataset, embeddings: Embeddings):
|
||||
super().__init__(dataset)
|
||||
self._embeddings = embeddings
|
||||
self._vector_store = None
|
||||
|
||||
def get_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_index_struct(self) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_vector_store_class(self) -> type:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
|
||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||
|
||||
if search_type == 'similarity_score_threshold':
|
||||
score_threshold = search_kwargs.get("score_threshold")
|
||||
if (score_threshold is None) or (not isinstance(score_threshold, float)):
|
||||
search_kwargs['score_threshold'] = .0
|
||||
|
||||
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
|
||||
query, **search_kwargs
|
||||
)
|
||||
|
||||
docs = []
|
||||
for doc, similarity in docs_with_similarity:
|
||||
doc.metadata['score'] = similarity
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
# similarity k
|
||||
# mmr k, fetch_k, lambda_mult
|
||||
# similarity_score_threshold k
|
||||
return vector_store.as_retriever(
|
||||
search_type=search_type,
|
||||
search_kwargs=search_kwargs
|
||||
).get_relevant_documents(query)
|
||||
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return vector_store.as_retriever(**kwargs)
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
if kwargs.get('duplicate_check', False):
|
||||
texts = self._filter_duplicate_texts(texts)
|
||||
|
||||
uuids = self._get_uuids(texts)
|
||||
vector_store.add_documents(texts, uuids=uuids)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return vector_store.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
for node_id in ids:
|
||||
vector_store.del_text(node_id)
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
return False
|
||||
|
||||
def recreate_dataset(self, dataset: Dataset):
|
||||
logging.info(f"Recreating dataset {dataset.id}")
|
||||
|
||||
try:
|
||||
self.delete()
|
||||
except UnexpectedStatusCodeException as e:
|
||||
if e.status_code != 400:
|
||||
# 400 means index not exists
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
origin_index_struct = self.dataset.index_struct
|
||||
self.dataset.index_struct = None
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create(documents)
|
||||
except Exception as e:
|
||||
self.dataset.index_struct = origin_index_struct
|
||||
raise e
|
||||
|
||||
dataset.index_struct = json.dumps(self.to_index_struct())
|
||||
|
||||
db.session.commit()
|
||||
|
||||
self.dataset = dataset
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
116
api/core/index/vector_index/qdrant_vector_index.py
Normal file
116
api/core/index/vector_index/qdrant_vector_index.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import os
|
||||
from typing import Optional, Any, List, cast
|
||||
|
||||
import qdrant_client
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.qdrant_vector_store import QdrantVectorStore
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
root_path: Optional[str]
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {
|
||||
'path': path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
}
|
||||
|
||||
|
||||
class QdrantVectorIndex(BaseVectorIndex):
|
||||
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client_config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'qdrant'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
return self.dataset.index_struct_dict['vector_store']['collection_name']
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Index_" + dataset_id.replace("-", "_")
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = QdrantVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
ids=uuids,
|
||||
content_payload_key='text',
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
client = qdrant_client.QdrantClient(
|
||||
**self._client_config.to_qdrant_params()
|
||||
)
|
||||
|
||||
return QdrantVectorStore(
|
||||
client=client,
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embeddings=self._embeddings,
|
||||
content_payload_key='text'
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return QdrantVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchValue(value=document_id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
|
||||
if class_prefix.startswith('Vector_'):
|
||||
# original class_prefix
|
||||
return True
|
||||
|
||||
return False
|
||||
69
api/core/index/vector_index/vector_index.py
Normal file
69
api/core/index/vector_index/vector_index.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document
|
||||
|
||||
|
||||
class VectorIndex:
|
||||
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
|
||||
self._dataset = dataset
|
||||
self._embeddings = embeddings
|
||||
self._vector_index = self._init_vector_index(dataset, config, embeddings)
|
||||
|
||||
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError(f"Vector store must be specified.")
|
||||
|
||||
if vector_type == "weaviate":
|
||||
from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
|
||||
|
||||
return WeaviateVectorIndex(
|
||||
dataset=dataset,
|
||||
config=WeaviateConfig(
|
||||
endpoint=config.get('WEAVIATE_ENDPOINT'),
|
||||
api_key=config.get('WEAVIATE_API_KEY'),
|
||||
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
elif vector_type == "qdrant":
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
return QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=config.get('QDRANT_URL'),
|
||||
api_key=config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
if not self._dataset.index_struct_dict:
|
||||
self._vector_index.create(texts, **kwargs)
|
||||
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
|
||||
db.session.commit()
|
||||
return
|
||||
|
||||
self._vector_index.add_texts(texts, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._vector_index is not None:
|
||||
method = getattr(self._vector_index, name)
|
||||
if callable(method):
|
||||
return method
|
||||
|
||||
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
|
||||
|
||||
132
api/core/index/vector_index/weaviate_vector_index.py
Normal file
132
api/core/index/vector_index/weaviate_vector_index.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
import weaviate
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class WeaviateVectorIndex(BaseVectorIndex):
|
||||
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client = self._init_client(config)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
||||
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
|
||||
|
||||
weaviate.connect.connection.has_grpc = False
|
||||
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint,
|
||||
auth_client_secret=auth_config,
|
||||
timeout_config=(5, 60),
|
||||
startup_period=None
|
||||
)
|
||||
|
||||
client.batch.configure(
|
||||
# `batch_size` takes an `int` value to enable auto-batching
|
||||
# (`None` is used for manual batching)
|
||||
batch_size=config.batch_size,
|
||||
# dynamically update the `batch_size` based on import speed
|
||||
dynamic=True,
|
||||
# `timeout_retries` takes an `int` value to retry on time outs
|
||||
timeout_retries=3,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'weaviate'
|
||||
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
|
||||
return WeaviateVectorStore(
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
text_key='text',
|
||||
embedding=self._embeddings,
|
||||
attributes=attributes,
|
||||
by_text=False
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return WeaviateVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": ["document_id"],
|
||||
"valueText": document_id
|
||||
})
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
return True
|
||||
|
||||
return False
|
||||
Reference in New Issue
Block a user