Source code for langchain_community.retrievers.dria_index
"""Dria Retriever的封装器。
"""
from typing import Any, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities import DriaAPIWrapper
[docs]class DriaRetriever(BaseRetriever):
"""使用DriaAPIWrapper检索Dria。"""
api_wrapper: DriaAPIWrapper
def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
"""使用DriaAPIWrapper实例初始化DriaRetriever。
参数:
api_key:Dria的API密钥。
contract_id:要交互的知识库的合同ID。
"""
api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
super().__init__(api_wrapper=api_wrapper, **kwargs) # type: ignore[call-arg]
[docs] def create_knowledge_base(
self,
name: str,
description: str,
category: str = "Unspecified",
embedding: str = "jina",
) -> str:
"""在Dria中创建一个新的知识库。
参数:
name: 知识库的名称。
description: 知识库的描述。
category: 知识库的类别。
embedding: 用于知识库的嵌入模型。
返回:
创建的知识库的ID。
"""
response = self.api_wrapper.create_knowledge_base(
name, description, category, embedding
)
return response
[docs] def add_texts(
self,
texts: List,
) -> None:
"""向Dria知识库添加文本。
参数:
texts:要添加到知识库的文本和元数据的可迭代对象。
返回:
代表已添加文本的ID列表。
"""
data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
self.api_wrapper.insert_data(data)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""从Dria根据查询检索相关文档。
参数:
query:在知识库中搜索的查询字符串。
run_manager:用于检索运行的回调管理器。
返回:
包含搜索结果的文档列表。
"""
results = self.api_wrapper.search(query)
docs = [
Document(
page_content=result["metadata"],
metadata={"id": result["id"], "score": result["score"]},
)
for result in results
]
return docs