Source code for langchain_community.retrievers.bm25

from __future__ import annotations

from typing import Any, Callable, Dict, Iterable, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever


[docs]def default_preprocessing_func(text: str) -> List[str]: return text.split()
[docs]class BM25Retriever(BaseRetriever): """`BM25` 检索器,不使用Elasticsearch。""" vectorizer: Any """BM25向量化器。""" docs: List[Document] = Field(repr=False) """文档列表。""" k: int = 4 """返回的文档数量。""" preprocess_func: Callable[[str], List[str]] = default_preprocessing_func """在BM25向量化之前用于文本的预处理函数。""" class Config: """此pydantic对象的配置。""" arbitrary_types_allowed = True
[docs] @classmethod def from_texts( cls, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None, bm25_params: Optional[Dict[str, Any]] = None, preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, **kwargs: Any, ) -> BM25Retriever: """从文本列表创建一个BM25Retriever。 参数: texts:要进行向量化的文本列表。 metadatas:要与每个文本关联的元数据字典列表。 bm25_params:传递给BM25向量化器的参数。 preprocess_func:在向量化之前对每个文本进行预处理的函数。 **kwargs:要传递给检索器的其他参数。 返回: 一个BM25Retriever实例。 """ try: from rank_bm25 import BM25Okapi except ImportError: raise ImportError( "Could not import rank_bm25, please install with `pip install " "rank_bm25`." ) texts_processed = [preprocess_func(t) for t in texts] bm25_params = bm25_params or {} vectorizer = BM25Okapi(texts_processed, **bm25_params) metadatas = metadatas or ({} for _ in texts) docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)] return cls( vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs )
[docs] @classmethod def from_documents( cls, documents: Iterable[Document], *, bm25_params: Optional[Dict[str, Any]] = None, preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, **kwargs: Any, ) -> BM25Retriever: """从文档列表创建一个BM25Retriever。 参数: documents: 一个要进行向量化的文档列表。 bm25_params: 传递给BM25向量化器的参数。 preprocess_func: 在向量化之前对每个文本进行预处理的函数。 **kwargs: 传递给检索器的任何其他参数。 返回: 一个BM25Retriever实例。 """ texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) return cls.from_texts( texts=texts, bm25_params=bm25_params, metadatas=metadatas, preprocess_func=preprocess_func, **kwargs, )
def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: processed_query = self.preprocess_func(query) return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k) return return_docs