Source code for langchain_community.retrievers.kay
from __future__ import annotations
from typing import Any, List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class KayAiRetriever(BaseRetriever):
    """Retriever for Kay.ai datasets.

    To work properly, the ``KAY_API_KEY`` environment variable must be set.
    You can get a key for free at https://kay.ai/.
    """

    # Kay client object; ``create`` stores a ``kay.rag.retrievers.KayRetriever``
    # here. Only its ``query(query=..., num_context=...)`` method is used.
    client: Any
    # Number of contexts requested from the client on every query.
    num_contexts: int

    @classmethod
    def create(
        cls,
        dataset_id: str,
        data_types: List[str],
        num_contexts: int = 6,
    ) -> KayAiRetriever:
        """Create a KayAiRetriever from a Kay dataset id and a list of data sources.

        Args:
            dataset_id: A dataset id category in Kay, like "company".
            data_types: A list of data sources present within the dataset. For
                "company" the corresponding data sources could be
                ["10-K", "10-Q", "8-K", "PressRelease"].
            num_contexts: The number of documents to retrieve on each query.
                Defaults to 6.

        Returns:
            A retriever wrapping a freshly constructed ``KayRetriever`` client.

        Raises:
            ImportError: If the optional ``kay`` package is not installed.
        """
        # ``kay`` is an optional dependency, so it is imported lazily here
        # rather than at module import time.
        try:
            from kay.rag.retrievers import KayRetriever
        except ImportError as e:
            # Chain the original error so the underlying import failure stays
            # visible as the explicit cause in the traceback.
            raise ImportError(
                "Could not import kay python package. Please install it with "
                "`pip install kay`.",
            ) from e

        client = KayRetriever(dataset_id, data_types)
        return cls(client=client, num_contexts=num_contexts)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the Kay client and convert each returned context to a Document.

        Args:
            query: The query string forwarded to the client.
            run_manager: Callback manager for this retriever run; not used by
                this implementation.

        Returns:
            One ``Document`` per returned context that has a
            ``"chunk_embed_text"`` entry. That entry becomes ``page_content``
            and the remaining keys become ``metadata``. Contexts whose
            ``"chunk_embed_text"`` is missing or ``None`` are skipped.
        """
        ctxs = self.client.query(query=query, num_context=self.num_contexts)
        docs: List[Document] = []
        for ctx in ctxs:
            # Work on a shallow copy so that removing the text field does not
            # mutate the objects owned by the client.
            metadata = dict(ctx)
            page_content = metadata.pop("chunk_embed_text", None)
            if page_content is None:
                continue
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs