Source code for langchain_community.retrievers.embedchain
"""包装Embedchain Retriever的封装器。"""
from __future__ import annotations
from typing import Any, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
[docs]class EmbedchainRetriever(BaseRetriever):
"""`Embedchain` 检索器。"""
client: Any
"""Embedchain流水线."""
[docs] @classmethod
def create(cls, yaml_path: Optional[str] = None) -> EmbedchainRetriever:
"""从一个YAML配置文件创建一个EmbedchainRetriever。
参数:
yaml_path: YAML配置文件的路径。如果未提供,则使用默认配置。
返回:
EmbedchainRetriever的一个实例。
"""
from embedchain import Pipeline
# Create an Embedchain Pipeline instance
if yaml_path:
client = Pipeline.from_config(yaml_path=yaml_path)
else:
client = Pipeline()
return cls(client=client)
[docs] def add_texts(
self,
texts: Iterable[str],
) -> List[str]:
"""运行更多的文本通过嵌入,并添加到检索器中。
参数:
texts:要添加到检索器中的字符串/URL的可迭代对象。
返回:
将文本添加到检索器中后的id列表。
"""
ids = []
for text in texts:
_id = self.client.add(text)
ids.append(_id)
return ids
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
res = self.client.search(query)
docs = []
for r in res:
docs.append(
Document(
page_content=r["context"],
metadata={
"source": r["metadata"]["url"],
"document_id": r["metadata"]["doc_id"],
},
)
)
return docs