Source code for langchain_community.retrievers.bedrock
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.retrievers import BaseRetriever
class VectorSearchConfig(BaseModel):
    """Configuration for vector search."""

    # Number of results requested from the vector search.
    numberOfResults: int = 4

    class Config:
        # Fields beyond the ones declared above are accepted rather than
        # rejected, so additional search options can be supplied.
        extra = "allow"
class RetrievalConfig(BaseModel):
    """Configuration for retrieval."""

    # Nested vector-search settings; required (no default).
    vectorSearchConfiguration: VectorSearchConfig

    class Config:
        # Fields beyond the ones declared above are accepted rather than
        # rejected, so additional retrieval options can be supplied.
        extra = "allow"
class AmazonKnowledgeBasesRetriever(BaseRetriever):
    """`Amazon Bedrock Knowledge Bases` retrieval.

    See https://aws.amazon.com/bedrock/knowledge-bases for more information.

    Args:
        knowledge_base_id: Knowledge base ID.
        region_name: The AWS region, e.g. `us-west-2`.
            Falls back to the AWS_DEFAULT_REGION environment variable or the
            region specified in ~/.aws/config.
        credentials_profile_name: Name of the profile in ~/.aws/credentials
            or ~/.aws/config that holds the access keys or role information.
            If not specified, the default credential profile is used or, on
            an EC2 instance, credentials from IMDS.
        endpoint_url: Optional endpoint URL; forwarded to the boto3 client
            only when set.
        client: boto3 client for the Bedrock agent runtime. When supplied,
            no client is created and the other connection settings are
            not used.
        retrieval_config: Configuration for retrieval.

    Example:
        .. code-block:: python

            from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

            retriever = AmazonKnowledgeBasesRetriever(
                knowledge_base_id="<knowledge-base-id>",
                retrieval_config={
                    "vectorSearchConfiguration": {
                        "numberOfResults": 4
                    }
                },
            )
    """

    knowledge_base_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    endpoint_url: Optional[str] = None
    client: Any
    retrieval_config: RetrievalConfig

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create the boto3 client unless the caller already supplied one.

        Args:
            values: Raw field values passed to the model constructor.

        Returns:
            ``values``, with ``"client"`` populated.

        Raises:
            ImportError: If boto3 cannot be imported, or the installed boto3
                does not know the ``bedrock-agent-runtime`` service.
            ValueError: If any other error occurs while building the session
                or the client.
        """
        if values.get("client") is not None:
            return values

        try:
            import boto3
            from botocore.client import Config
            from botocore.exceptions import UnknownServiceError

            if values.get("credentials_profile_name"):
                session = boto3.Session(profile_name=values["credentials_profile_name"])
            else:
                # use default credentials
                session = boto3.Session()

            client_params: Dict[str, Any] = {
                "config": Config(
                    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
                )
            }
            # Only forward optional settings that were actually provided so
            # boto3 applies its own defaults otherwise.
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]
            if values.get("endpoint_url"):
                client_params["endpoint_url"] = values["endpoint_url"]

            values["client"] = session.client("bedrock-agent-runtime", **client_params)
            return values
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except UnknownServiceError as e:
            # The message names the same service that is requested above.
            raise ImportError(
                "Ensure that you have installed the latest boto3 package "
                "that contains the API for `bedrock-agent-runtime`."
            ) from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the knowledge base and wrap each result in a ``Document``.

        Args:
            query: Query text; surrounding whitespace is stripped before it
                is sent.
            run_manager: Callback manager for this retriever run (not used
                by this implementation).

        Returns:
            One ``Document`` per entry in the response's
            ``retrievalResults``. ``page_content`` is the result's
            ``content.text``; ``metadata`` holds the remaining result fields,
            with ``score`` defaulting to 0 and any ``metadata`` entry renamed
            to ``source_metadata``.
        """
        response = self.client.retrieve(
            retrievalQuery={"text": query.strip()},
            knowledgeBaseId=self.knowledge_base_id,
            retrievalConfiguration=self.retrieval_config.dict(),
        )

        documents = []
        for result in response["retrievalResults"]:
            text = result["content"]["text"]
            # Work on a shallow copy so the client's response object is left
            # intact; popping from it directly would break a second call if
            # the client hands back the same response dict again.
            metadata = {key: val for key, val in result.items() if key != "content"}
            if "score" not in metadata:
                metadata["score"] = 0
            if "metadata" in metadata:
                metadata["source_metadata"] = metadata.pop("metadata")
            documents.append(Document(page_content=text, metadata=metadata))

        return documents