Source code for langchain_community.retrievers.bm25
from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
[docs]def default_preprocessing_func(text: str) -> List[str]:
return text.split()
[docs]class BM25Retriever(BaseRetriever):
"""`BM25` 检索器,不使用Elasticsearch。"""
vectorizer: Any
"""BM25向量化器。"""
docs: List[Document] = Field(repr=False)
"""文档列表。"""
k: int = 4
"""返回的文档数量。"""
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
"""在BM25向量化之前用于文本的预处理函数。"""
class Config:
"""此pydantic对象的配置。"""
arbitrary_types_allowed = True
[docs] @classmethod
def from_texts(
cls,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
**kwargs: Any,
) -> BM25Retriever:
"""从文本列表创建一个BM25Retriever。
参数:
texts:要进行向量化的文本列表。
metadatas:要与每个文本关联的元数据字典列表。
bm25_params:传递给BM25向量化器的参数。
preprocess_func:在向量化之前对每个文本进行预处理的函数。
**kwargs:要传递给检索器的其他参数。
返回:
一个BM25Retriever实例。
"""
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"Could not import rank_bm25, please install with `pip install "
"rank_bm25`."
)
texts_processed = [preprocess_func(t) for t in texts]
bm25_params = bm25_params or {}
vectorizer = BM25Okapi(texts_processed, **bm25_params)
metadatas = metadatas or ({} for _ in texts)
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
return cls(
vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
)
[docs] @classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
**kwargs: Any,
) -> BM25Retriever:
"""从文档列表创建一个BM25Retriever。
参数:
documents: 一个要进行向量化的文档列表。
bm25_params: 传递给BM25向量化器的参数。
preprocess_func: 在向量化之前对每个文本进行预处理的函数。
**kwargs: 传递给检索器的任何其他参数。
返回:
一个BM25Retriever实例。
"""
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts,
bm25_params=bm25_params,
metadatas=metadatas,
preprocess_func=preprocess_func,
**kwargs,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
processed_query = self.preprocess_func(query)
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
return return_docs