Source code for langchain_community.retrievers.bedrock

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.retrievers import BaseRetriever


class VectorSearchConfig(BaseModel):
    """Configuration for vector search."""

    # Number of results to request; defaults to 4.
    numberOfResults: int = 4

    class Config:
        # Keep any additional, undeclared fields that callers supply.
        extra = "allow"
class RetrievalConfig(BaseModel):
    """Configuration for retrieval."""

    # Nested vector-search settings (required).
    vectorSearchConfiguration: VectorSearchConfig

    class Config:
        # Keep any additional, undeclared fields that callers supply.
        extra = "allow"
class AmazonKnowledgeBasesRetriever(BaseRetriever):
    """`Amazon Bedrock Knowledge Bases` retrieval.

    See https://aws.amazon.com/bedrock/knowledge-bases for more information.

    Args:
        knowledge_base_id: Knowledge base ID.
        region_name: The AWS region, e.g. `us-west-2`. Falls back to the
            AWS_DEFAULT_REGION environment variable or the region specified
            in ~/.aws/config.
        credentials_profile_name: The name of the profile in the
            ~/.aws/credentials or ~/.aws/config files, which has either
            access keys or role information specified. If not specified,
            the default credential profile or, if on an EC2 instance,
            credentials from IMDS will be used.
        endpoint_url: Optional endpoint URL forwarded to the boto3 client
            when set.
        client: boto3 client for the bedrock agent runtime. When omitted,
            one is created from the settings above.
        retrieval_config: Configuration for retrieval.

    Example:
        .. code-block:: python

            from langchain_community.retrievers import (
                AmazonKnowledgeBasesRetriever,
            )

            retriever = AmazonKnowledgeBasesRetriever(
                knowledge_base_id="<knowledge-base-id>",
                retrieval_config={
                    "vectorSearchConfiguration": {
                        "numberOfResults": 4
                    }
                },
            )
    """

    knowledge_base_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    endpoint_url: Optional[str] = None
    client: Any
    retrieval_config: RetrievalConfig

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create the `bedrock-agent-runtime` client unless one was supplied.

        Args:
            values: Raw field values passed to the model constructor.

        Returns:
            The same ``values`` mapping, with ``"client"`` populated.

        Raises:
            ImportError: If boto3 is not installed, or the installed boto3
                does not know the `bedrock-agent-runtime` service.
            ValueError: If any other error occurs while building the session
                or client.
        """
        # A caller-provided client always wins; nothing to build.
        if values.get("client") is not None:
            return values

        try:
            import boto3
            from botocore.client import Config
            from botocore.exceptions import UnknownServiceError

            if values.get("credentials_profile_name"):
                session = boto3.Session(profile_name=values["credentials_profile_name"])
            else:
                # use default credentials
                session = boto3.Session()

            client_params: Dict[str, Any] = {
                "config": Config(
                    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
                )
            }
            # Only forward optional settings that were actually provided.
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]
            if values.get("endpoint_url"):
                client_params["endpoint_url"] = values["endpoint_url"]

            values["client"] = session.client("bedrock-agent-runtime", **client_params)

            return values
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        except UnknownServiceError as e:
            # The message names the service actually requested above.
            raise ImportError(
                "Ensure that you have installed the latest boto3 package "
                "that contains the API for `bedrock-agent-runtime`."
            ) from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the knowledge base and wrap each result in a ``Document``.

        Args:
            query: Search text; surrounding whitespace is stripped before
                it is sent.
            run_manager: Callback manager for this retriever run (not used
                by this implementation).

        Returns:
            One ``Document`` per entry in the response's
            ``retrievalResults``. ``page_content`` is the result's
            ``content["text"]``; ``metadata`` holds the remaining result
            fields, with ``score`` defaulting to 0 and any ``metadata`` key
            renamed to ``source_metadata``.
        """
        response = self.client.retrieve(
            retrievalQuery={"text": query.strip()},
            knowledgeBaseId=self.knowledge_base_id,
            retrievalConfiguration=self.retrieval_config.dict(),
        )

        documents = []
        for result in response["retrievalResults"]:
            # Work on a shallow copy so the client's response object is left
            # untouched (a reused or stubbed response stays valid).
            metadata = dict(result)
            content = metadata.pop("content")["text"]
            metadata.setdefault("score", 0)
            if "metadata" in metadata:
                metadata["source_metadata"] = metadata.pop("metadata")
            documents.append(
                Document(
                    page_content=content,
                    metadata=metadata,
                )
            )

        return documents