Source code for langchain_community.retrievers.embedchain

"""包装Embedchain Retriever的封装器。"""

from __future__ import annotations

from typing import Any, Iterable, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


[docs]class EmbedchainRetriever(BaseRetriever): """`Embedchain` 检索器。""" client: Any """Embedchain流水线."""
[docs] @classmethod def create(cls, yaml_path: Optional[str] = None) -> EmbedchainRetriever: """从一个YAML配置文件创建一个EmbedchainRetriever。 参数: yaml_path: YAML配置文件的路径。如果未提供,则使用默认配置。 返回: EmbedchainRetriever的一个实例。 """ from embedchain import Pipeline # Create an Embedchain Pipeline instance if yaml_path: client = Pipeline.from_config(yaml_path=yaml_path) else: client = Pipeline() return cls(client=client)
[docs] def add_texts( self, texts: Iterable[str], ) -> List[str]: """运行更多的文本通过嵌入,并添加到检索器中。 参数: texts:要添加到检索器中的字符串/URL的可迭代对象。 返回: 将文本添加到检索器中后的id列表。 """ ids = [] for text in texts: _id = self.client.add(text) ids.append(_id) return ids
def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: res = self.client.search(query) docs = [] for r in res: docs.append( Document( page_content=r["context"], metadata={ "source": r["metadata"]["url"], "document_id": r["metadata"]["doc_id"], }, ) ) return docs