Files
dify/api/core/rag/splitter/text_splitter.py
2026-04-02 05:07:32 +00:00

291 lines
11 KiB
Python

from __future__ import annotations
import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import Any, Literal
from core.rag.models.document import BaseDocumentTransformer, Document
logger = logging.getLogger(__name__)
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({re.escape(separator)})", text)
splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 != 0:
splits += _splits[-1:]
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if (s not in {"", "\n"})]
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement :meth:`split_text`; this base class supplies the
    shared configuration, the conversion between raw texts and ``Document``
    objects, and the helper that merges small pieces into sized chunks.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[list[str]], list[int]] = lambda texts: [len(t) for t in texts],
        keep_separator: bool = False,
        add_start_index: bool = False,
    ):
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return.
            chunk_overlap: Overlap between chunks, in the unit reported by
                ``length_function`` (characters by default).
            length_function: Batch function that maps a list of strings to
                the list of their lengths.
            keep_separator: Whether to keep the separator in the chunks.
            add_start_index: If ``True``, includes chunk's start index in
                metadata.

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]:
        """Create documents from a list of texts.

        Args:
            texts: Texts to split; each one is passed to :meth:`split_text`.
            metadatas: Optional metadata per text, aligned with ``texts`` by
                position. When omitted (or empty) every text gets ``{}``.

        Returns:
            One ``Document`` per produced chunk. Each carries its own deep
            copy of the source text's metadata, plus ``start_index`` when the
            splitter was created with ``add_start_index=True``.
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            # Position of the previous chunk; searching from ``index + 1``
            # makes repeated chunk texts resolve to successive occurrences.
            index = -1
            for chunk in self.split_text(text):
                # Deep copy so chunks never share (or mutate) one metadata dict.
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                documents.append(Document(page_content=chunk, metadata=metadata))
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents.

        Extracts ``page_content`` and ``metadata`` (``{}`` when falsy) from
        every document and delegates to :meth:`create_documents`.
        """
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata or {})
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> str | None:
        """Join pieces with ``separator`` and strip surrounding whitespace.

        Returns:
            The joined text, or ``None`` when nothing but whitespace remains.
        """
        text = separator.join(docs).strip()
        return text or None

    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
        """Combine small pieces into chunks of at most ``chunk_size``.

        Pieces are appended to a running window. When the next piece would
        push the window past ``chunk_size``, the window is emitted as a chunk
        and pieces are dropped from its front until the remainder is no
        larger than ``chunk_overlap`` and the next piece fits; the remainder
        is carried into the next chunk as overlap.

        Args:
            splits: Pieces to merge, in order.
            separator: String inserted between pieces when joining.
            lengths: Length of each piece, aligned with ``splits`` by
                position. Expected to equal ``length_function(splits)``; the
                values are reused when pieces leave the window so the length
                function is not invoked again for them.

        Returns:
            The merged chunks; windows that join to only whitespace are
            skipped.
        """
        separator_len = self._length_function([separator])[0]

        docs: list[str] = []
        current_doc: list[str] = []
        # Lengths of the pieces in ``current_doc``, kept in lockstep with it.
        current_lens: list[int] = []
        total = 0
        for d, _len in zip(splits, lengths):
            if total + _len + (separator_len if current_doc else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size
                    )
                if current_doc:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if current_doc else 0) > self._chunk_size and total > 0
                    ):
                        # The separator share is only counted while another
                        # piece follows the one being removed.
                        total -= current_lens.pop(0) + (separator_len if len(current_doc) > 1 else 0)
                        current_doc.pop(0)
            current_doc.append(d)
            current_lens.append(_len)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length.

        Args:
            tokenizer: A ``transformers.PreTrainedTokenizerBase`` instance.
            **kwargs: Forwarded to the splitter constructor.

        Returns:
            A splitter whose length function counts the tokens produced by
            ``tokenizer.encode``.

        Raises:
            ValueError: If ``transformers`` cannot be imported, or if
                ``tokenizer`` is not a ``PreTrainedTokenizerBase``.
        """
        # Only the import is guarded, so unrelated failures are not mistaken
        # for a missing dependency; the original ImportError is kept as cause.
        try:
            from transformers import PreTrainedTokenizerBase
        except ImportError as err:
            raise ValueError(
                "Could not import transformers python package. Please install it with `pip install transformers`."
            ) from err

        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

        def _huggingface_tokenizer_length(text: str) -> int:
            return len(tokenizer.encode(text))

        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)

    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them.

        Raises:
            NotImplementedError: Always; no asynchronous variant is provided.
        """
        raise NotImplementedError
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    """Immutable bundle of the settings and callables used to split text by tokens."""

    # Number of tokens that consecutive chunks share.
    chunk_overlap: int
    # Upper bound on the number of tokens placed in a single chunk.
    tokens_per_chunk: int
    # Converts a list of token ids back into text.
    decode: Callable[[list[int]], str]
    # Converts text into a list of token ids.
    encode: Callable[[str], list[int]]
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded once, then cut into windows of at most
    ``tokenizer.tokens_per_chunk`` token ids. Each window starts
    ``tokens_per_chunk - chunk_overlap`` ids after the previous one and is
    decoded back to text.

    Args:
        text: The text to split.
        tokenizer: Provides ``encode``/``decode`` plus the window size
            (``tokens_per_chunk``) and the overlap (``chunk_overlap``).

    Returns:
        The decoded chunks in order; an empty list when the text encodes to
        no tokens.

    Raises:
        ValueError: If ``tokens_per_chunk`` is not greater than
            ``chunk_overlap``. The window could then never advance and the
            loop would not terminate.
    """
    stride = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
    if stride <= 0:
        raise ValueError(
            f"tokens_per_chunk ({tokenizer.tokens_per_chunk}) must be greater than "
            f"chunk_overlap ({tokenizer.chunk_overlap})."
        )

    input_ids = tokenizer.encode(text)
    total = len(input_ids)

    splits: list[str] = []
    start_idx = 0
    while start_idx < total:
        end_idx = min(start_idx + tokenizer.tokens_per_chunk, total)
        splits.append(tokenizer.decode(input_ids[start_idx:end_idx]))
        if end_idx == total:
            # This window already reached the end of the input. Advancing
            # again would only emit a chunk made of tokens it contains.
            break
        start_idx += stride
    return splits
class TokenTextSplitter(TextSplitter):
    """Splitter that measures and cuts chunks in tokens of a tiktoken encoding."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: str | None = None,
        allowed_special: Literal["all"] | Set[str] = set(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter.

        Args:
            encoding_name: Name of the tiktoken encoding, used when no
                ``model_name`` is given.
            model_name: If set, the encoding is looked up for this model
                instead of using ``encoding_name``.
            allowed_special: Passed through to the encoder on every encode.
            disallowed_special: Passed through to the encoder on every encode.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If the ``tiktoken`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        # An explicit model name takes precedence over the encoding name.
        self._tokenizer = (
            tiktoken.get_encoding(encoding_name)
            if model_name is None
            else tiktoken.encoding_for_model(model_name)
        )
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        """Split ``text`` into chunks of at most ``chunk_size`` tokens.

        The configured chunk size and overlap are interpreted as token
        counts and handed to ``split_text_on_tokens``.
        """
        encoding = self._tokenizer

        def _to_token_ids(piece: str) -> list[int]:
            # Apply the special-token policy chosen at construction time.
            return encoding.encode(
                piece,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        return split_text_on_tokens(
            text=text,
            tokenizer=Tokenizer(
                chunk_overlap=self._chunk_overlap,
                tokens_per_chunk=self._chunk_size,
                decode=encoding.decode,
                encode=_to_token_ids,
            ),
        )
class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works. Separators are tried in the given order; pieces that are
    still too large are split again with the remaining separators.
    """

    def __init__(
        self,
        separators: list[str] | None = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ):
        """Create a new TextSplitter.

        Args:
            separators: Separators to try, most preferred first. When omitted
                or empty the defaults are used: blank line, newline, space,
                and finally the empty string (split into characters).
            keep_separator: Whether separators stay inside the pieces.
            **kwargs: Forwarded to ``TextSplitter.__init__``.
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split ``text`` with the first usable separator, recursing on oversized pieces.

        Args:
            text: The text to split.
            separators: Candidate separators in order of preference.

        Returns:
            The resulting chunks in order.
        """
        final_chunks: list[str] = []

        # Fall back to the last separator when none of the earlier ones occurs.
        separator = separators[-1]
        new_separators: list[str] = []
        for i, candidate in enumerate(separators):
            if candidate == "":
                separator = candidate
                break
            # Literal substring test: a separator is plain text, so regex
            # metacharacters in it (".", "(", "|", ...) must neither match
            # unrelated text nor raise a pattern error.
            if candidate in text:
                separator = candidate
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)

        # Kept separators are already part of the pieces, so merging must not
        # insert them a second time.
        merge_separator = "" if self._keep_separator else separator

        pending: list[str] = []
        pending_lengths: list[int] = []  # cache the lengths of the splits
        for piece, piece_len in zip(splits, self._length_function(splits)):
            if piece_len < self._chunk_size:
                pending.append(piece)
                pending_lengths.append(piece_len)
                continue

            # An oversized piece ends the current run: flush what was
            # collected so far to keep the output in document order.
            if pending:
                final_chunks.extend(self._merge_splits(pending, merge_separator, pending_lengths))
                pending = []
                pending_lengths = []

            if new_separators:
                final_chunks.extend(self._split_text(piece, new_separators))
            else:
                # No finer separator is left; emit the piece as it is.
                final_chunks.append(piece)

        if pending:
            final_chunks.extend(self._merge_splits(pending, merge_separator, pending_lengths))
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        """Split ``text`` using the separators configured on this instance."""
        return self._split_text(text, self._separators)