Source code for langchain_experimental.text_splitter

"""基于语义相似性的实验性 **文本分割器** 。"""
import copy
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast

import numpy as np
from langchain_community.utils.math import (
    cosine_similarity,
)
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.embeddings import Embeddings


def combine_sentences(sentences: List[dict], buffer_size: int = 1) -> List[dict]:
    """Combine sentences based on a buffer size.

    Args:
        sentences: List of sentences to combine.
        buffer_size: Number of sentences to combine. Defaults to 1.

    Returns:
        List of sentences with the combined sentences added.
    """
    # Go through each sentence dict
    for i in range(len(sentences)):
        # Create a string that will hold the joined sentences
        combined_sentence = ""

        # Add sentences before the current one, based on the buffer size
        for j in range(i - buffer_size, i):
            # Check that the index j is not negative
            # (to avoid an index-out-of-range error on the first sentences)
            if j >= 0:
                # Add the sentence at index j to the combined_sentence string
                combined_sentence += sentences[j]["sentence"] + " "

        # Add the current sentence
        combined_sentence += sentences[i]["sentence"]

        # Add sentences after the current one, based on the buffer size
        for j in range(i + 1, i + 1 + buffer_size):
            # Check that the index j is within the range of the sentences list
            if j < len(sentences):
                # Add the sentence at index j to the combined_sentence string
                combined_sentence += " " + sentences[j]["sentence"]

        # Store the combined sentence in the current sentence dict
        sentences[i]["combined_sentence"] = combined_sentence

    return sentences
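
# Illustrative walkthrough (not part of the original source): with
# buffer_size=1, the sentences ["A.", "B.", "C."] produce the combined
# windows
#   index 0 -> "A. B."
#   index 1 -> "A. B. C."
#   index 2 -> "B. C."
# so each sentence is embedded together with its immediate neighbors.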
def calculate_cosine_distances(
    sentences: List[dict],
) -> Tuple[List[float], List[dict]]:
    """Calculate cosine distances between sentences.

    Args:
        sentences: List of sentences to calculate distances for.

    Returns:
        Tuple of distances and sentences.
    """
    distances = []
    for i in range(len(sentences) - 1):
        embedding_current = sentences[i]["combined_sentence_embedding"]
        embedding_next = sentences[i + 1]["combined_sentence_embedding"]

        # Calculate cosine similarity
        similarity = cosine_similarity([embedding_current], [embedding_next])[0][0]

        # Convert to cosine distance
        distance = 1 - similarity

        # Append the cosine distance to the list
        distances.append(distance)

        # Store the distance in the dict
        sentences[i]["distance_to_next"] = distance

    # Optionally handle the last sentence
    # sentences[-1]["distance_to_next"] = None  # or a default value

    return distances, sentences
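
# Illustrative example (not part of the original source): if two adjacent
# combined-sentence embeddings have cosine similarity 0.83, the recorded
# distance is 1 - 0.83 = 0.17. Small distances mean the adjacent windows are
# semantically close; large distances are candidate breakpoints.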
BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]

BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
    "percentile": 95,
    "standard_deviation": 3,
    "interquartile": 1.5,
}
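
# Illustrative reading of the defaults above (derived from
# _calculate_breakpoint_threshold below, not part of the original source):
# - "percentile": split where a distance exceeds the 95th percentile of all
#   inter-sentence distances.
# - "standard_deviation": split where a distance exceeds mean + 3 * std.
# - "interquartile": split where a distance exceeds mean + 1.5 * IQR.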
class SemanticChunker(BaseDocumentTransformer):
    """Split the text based on semantic similarity.

    Taken from Greg Kamradt's wonderful notebook:
    https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/tutorials/LevelsOfTextSplitting/5_Levels_Of_Text_Splitting.ipynb

    All credit to him.

    At a high level, this splits the text into sentences, groups them into
    windows of three sentences, and then merges adjacent groups that are
    similar in embedding space.
    """
    def __init__(
        self,
        embeddings: Embeddings,
        buffer_size: int = 1,
        add_start_index: bool = False,
        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
        breakpoint_threshold_amount: Optional[float] = None,
        number_of_chunks: Optional[int] = None,
        sentence_split_regex: str = r"(?<=[.?!])\s+",
    ):
        self._add_start_index = add_start_index
        self.embeddings = embeddings
        self.buffer_size = buffer_size
        self.breakpoint_threshold_type = breakpoint_threshold_type
        self.number_of_chunks = number_of_chunks
        self.sentence_split_regex = sentence_split_regex
        if breakpoint_threshold_amount is None:
            self.breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[
                breakpoint_threshold_type
            ]
        else:
            self.breakpoint_threshold_amount = breakpoint_threshold_amount
    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
        if self.breakpoint_threshold_type == "percentile":
            return cast(
                float,
                np.percentile(distances, self.breakpoint_threshold_amount),
            )
        elif self.breakpoint_threshold_type == "standard_deviation":
            return cast(
                float,
                np.mean(distances)
                + self.breakpoint_threshold_amount * np.std(distances),
            )
        elif self.breakpoint_threshold_type == "interquartile":
            q1, q3 = np.percentile(distances, [25, 75])
            iqr = q3 - q1
            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
        else:
            raise ValueError(
                f"Got unexpected `breakpoint_threshold_type`: "
                f"{self.breakpoint_threshold_type}"
            )

    def _threshold_from_clusters(self, distances: List[float]) -> float:
        """Calculate the threshold based on the number of chunks.

        Inverse of the percentile method.
        """
        if self.number_of_chunks is None:
            raise ValueError(
                "This should never be called if `number_of_chunks` is None."
            )
        x1, y1 = len(distances), 0.0
        x2, y2 = 1.0, 100.0

        x = max(min(self.number_of_chunks, x1), x2)

        # Linear interpolation formula: given two points (x1, y1) and (x2, y2),
        # the interpolated value y at a point x on the x-axis is
        #   y = y1 + (x - x1) * (y2 - y1) / (x2 - x1)
        y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
        y = min(max(y, 0), 100)

        return cast(float, np.percentile(distances, y))

    def _calculate_sentence_distances(
        self, single_sentences_list: List[str]
    ) -> Tuple[List[float], List[dict]]:
        """Split text into multiple components."""
        _sentences = [
            {"sentence": x, "index": i} for i, x in enumerate(single_sentences_list)
        ]
        sentences = combine_sentences(_sentences, self.buffer_size)
        embeddings = self.embeddings.embed_documents(
            [x["combined_sentence"] for x in sentences]
        )
        for i, sentence in enumerate(sentences):
            sentence["combined_sentence_embedding"] = embeddings[i]

        return calculate_cosine_distances(sentences)
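
    # Worked example for _threshold_from_clusters (illustrative only, not part
    # of the original source): with 21 inter-sentence distances and
    # number_of_chunks=5, the interpolation maps 5 onto the percentile axis as
    #   y = 0 + (100 / (1 - 21)) * (5 - 21) = 80
    # so the 80th-percentile distance becomes the threshold, leaving roughly
    # 20% of the 21 gaps (about 4 breakpoints) above it, i.e. about 5 chunks.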
    def split_text(
        self,
        text: str,
    ) -> List[str]:
        # Split the input text into sentences on the default punctuation
        # ('.', '?' and '!').
        single_sentences_list = re.split(self.sentence_split_regex, text)

        # Having len(single_sentences_list) == 1 would cause the following
        # np.percentile call to fail.
        if len(single_sentences_list) == 1:
            return single_sentences_list
        distances, sentences = self._calculate_sentence_distances(
            single_sentences_list
        )
        if self.number_of_chunks is not None:
            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
        else:
            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
                distances
            )

        indices_above_thresh = [
            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
        ]

        chunks = []
        start_index = 0

        # Iterate through the breakpoints to slice the sentences
        for index in indices_above_thresh:
            # The end index is the current breakpoint
            end_index = index

            # Slice the sentence dicts from the current start index to the end index
            group = sentences[start_index : end_index + 1]
            combined_text = " ".join([d["sentence"] for d in group])
            chunks.append(combined_text)

            # Update the start index for the next group
            start_index = index + 1

        # Handle the last group, if any sentences remain
        if start_index < len(sentences):
            combined_text = " ".join([d["sentence"] for d in sentences[start_index:]])
            chunks.append(combined_text)
        return chunks
    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents
    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a sequence of documents by splitting them."""
        return self.split_documents(list(documents))
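

# --- Usage sketch (illustrative only, not part of the library source) -------
# A minimal end-to-end example of driving SemanticChunker. ``FakeEmbeddings``
# is a toy stand-in defined here just so the sketch is self-contained; in
# practice you would pass any real ``Embeddings`` implementation.
if __name__ == "__main__":

    class FakeEmbeddings(Embeddings):
        """Toy embeddings: map each text to a hash-seeded random vector."""

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [self.embed_query(t) for t in texts]

        def embed_query(self, text: str) -> List[float]:
            rng = np.random.default_rng(abs(hash(text)) % (2**32))
            return rng.random(16).tolist()

    chunker = SemanticChunker(
        embeddings=FakeEmbeddings(),
        breakpoint_threshold_type="percentile",
        breakpoint_threshold_amount=95,
    )
    sample = "Cats purr. Dogs bark. Stocks fell today. Markets were volatile."
    for chunk in chunker.split_text(sample):
        print(repr(chunk))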