Source code for langchain_community.retrievers.kay

from __future__ import annotations

from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class KayAiRetriever(BaseRetriever):
    """Retriever for Kay.ai datasets.

    To work properly, the ``KAY_API_KEY`` environment variable is expected to
    be set (this class does not read it itself; presumably the ``kay`` client
    does). You can get a key for free at https://kay.ai/.
    """

    # Client object built by ``create`` (a ``kay.rag.retrievers.KayRetriever``).
    client: Any
    # Number of documents requested from the client on each query.
    num_contexts: int

    @classmethod
    def create(
        cls,
        dataset_id: str,
        data_types: List[str],
        num_contexts: int = 6,
    ) -> KayAiRetriever:
        """Create a KayAiRetriever from a Kay dataset id and a list of datasources.

        Args:
            dataset_id: A dataset id category in Kay, like "company".
            data_types: A list of datasources present within the dataset. For
                "company" the corresponding datasources could be
                ["10-K", "10-Q", "8-K", "PressRelease"].
            num_contexts: The number of documents to retrieve on each query.
                Defaults to 6.

        Returns:
            A retriever wrapping a freshly constructed ``KayRetriever`` client.

        Raises:
            ImportError: If the ``kay`` package cannot be imported.
        """
        try:
            from kay.rag.retrievers import KayRetriever
        except ImportError as e:
            # Chain the original error so the real cause of the import failure
            # (e.g. a broken transitive dependency) is not hidden.
            raise ImportError(
                "Could not import kay python package. Please install it with "
                "`pip install kay`.",
            ) from e

        client = KayRetriever(dataset_id, data_types)
        return cls(client=client, num_contexts=num_contexts)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the Kay client and convert each returned context into a Document.

        Each context is treated as a mapping: its ``"chunk_embed_text"`` value
        becomes ``page_content`` and all remaining keys become metadata.
        Contexts where that value is missing or ``None`` are skipped.
        ``run_manager`` is not used here.
        """
        ctxs = self.client.query(query=query, num_context=self.num_contexts)
        docs = []
        for ctx in ctxs:
            page_content = ctx.get("chunk_embed_text")
            if page_content is None:
                continue
            # Build metadata from a filtered copy instead of popping from
            # ``ctx`` so the dicts returned by the client are left unmodified.
            metadata = {k: v for k, v in ctx.items() if k != "chunk_embed_text"}
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs