Source code for langchain_community.retrievers.kendra

import re
from abc import ABC, abstractmethod
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    root_validator,
    validator,
)
from langchain_core.retrievers import BaseRetriever
from typing_extensions import Annotated


def clean_excerpt(excerpt: str) -> str:
    """Clean an excerpt returned by Kendra.

    Args:
        excerpt: The excerpt to clean.

    Returns:
        The excerpt with every whitespace run collapsed to one space and all
        literal ``...`` sequences removed. A falsy input is returned as-is.
    """
    if excerpt:
        # Normalize whitespace first, then strip the ellipsis markers.
        collapsed = re.sub(r"\s+", " ", excerpt)
        return collapsed.replace("...", "")
    return excerpt
def combined_text(item: "ResultItem") -> str:
    """Combine a ResultItem title and excerpt into a single string.

    Args:
        item: The ResultItem of a Kendra search.

    Returns:
        The title section (if the title is non-empty) followed by the cleaned
        excerpt section (if the cleaned excerpt is non-empty).
    """
    sections = []

    title = item.get_title()
    if title:
        sections.append(f"Document Title: {title}\n")

    excerpt = clean_excerpt(item.get_excerpt())
    if excerpt:
        sections.append(f"Document Excerpt: \n{excerpt}\n")

    return "".join(sections)
# Type alias for the value returned by `DocumentAttributeValue.value`.
DocumentAttributeValueType = Union[str, int, List[str], None]
"""Possible types of a DocumentAttributeValue.

Dates are also represented as str.
"""


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class Highlight(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Information that highlights the keywords in the excerpt."""

    BeginOffset: int
    """The zero-based location in the excerpt where the highlight starts."""
    EndOffset: int
    """The zero-based location in the excerpt where the highlight ends."""
    TopAnswer: Optional[bool]
    """Indicates whether the result is the best one."""
    Type: Optional[str]
    """The highlight type: STANDARD or THESAURUS_SYNONYM."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class TextWithHighLights(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Text with highlights."""

    Text: str
    """The text."""
    Highlights: Optional[Any]
    """The highlights."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class AdditionalResultAttributeValue(  # type: ignore[call-arg]
    BaseModel, extra=Extra.allow
):
    """Value of an additional result attribute."""

    TextWithHighlightsValue: TextWithHighLights
    """The text with highlights value."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class AdditionalResultAttribute(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Additional result attribute."""

    Key: str
    """The key of the attribute."""
    ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
    """The type of the value."""
    Value: AdditionalResultAttributeValue
    """The value of the attribute."""

    def get_value_text(self) -> str:
        """Return the plain text held by this attribute's highlighted value."""
        highlighted = self.Value.TextWithHighlightsValue
        return highlighted.Text
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class DocumentAttributeValue(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Value of a document attribute."""

    DateValue: Optional[str]
    """The date expressed as an ISO 8601 string."""
    LongValue: Optional[int]
    """The long value."""
    StringListValue: Optional[List[str]]
    """The string list value."""
    StringValue: Optional[str]
    """The string value."""

    @property
    def value(self) -> DocumentAttributeValueType:
        """The only defined document attribute value or None.

        According to Amazon Kendra, you can only provide one
        value for a document attribute.

        Returns:
            The first field that is not None, checked in the order
            DateValue, LongValue, StringListValue, StringValue; None when
            no field is set.
        """
        candidates = (
            self.DateValue,
            self.LongValue,
            self.StringListValue,
            self.StringValue,
        )
        for candidate in candidates:
            # Compare against None rather than relying on truthiness so that
            # legitimate falsy values (LongValue == 0, "", []) are not
            # mistaken for "unset".
            if candidate is not None:
                return candidate
        return None
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class DocumentAttribute(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Document attribute."""

    Key: str
    """The key of the attribute."""
    Value: DocumentAttributeValue
    """The value of the attribute."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class ResultItem(BaseModel, ABC, extra=Extra.allow):  # type: ignore[call-arg]
    """Base class of a result item."""

    Id: Optional[str]
    """The ID of the relevant result item."""
    DocumentId: Optional[str]
    """The document ID."""
    DocumentURI: Optional[str]
    """The document URI."""
    DocumentAttributes: Optional[List[DocumentAttribute]] = []
    """The document attributes."""
    ScoreAttributes: Optional[dict]
    """The kendra score confidence."""

    @abstractmethod
    def get_title(self) -> str:
        """Document title."""

    @abstractmethod
    def get_excerpt(self) -> str:
        """Document excerpt or passage original content as retrieved by Kendra."""

    def get_additional_metadata(self) -> dict:
        """Document additional metadata dict.

        This returns any extra metadata except these:
            * result_id
            * document_id
            * source
            * title
            * excerpt
            * document_attributes
        """
        return {}

    def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
        """Document attributes dict, keyed by attribute key."""
        return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}

    def get_score_attribute(self) -> str:
        """Document score confidence.

        Returns:
            The ``"ScoreConfidence"`` entry of ``ScoreAttributes``, or
            ``"NOT_AVAILABLE"`` when ``ScoreAttributes`` is None or does not
            contain that key.
        """
        if self.ScoreAttributes is None:
            return "NOT_AVAILABLE"
        # Use .get so a ScoreAttributes dict without "ScoreConfidence" yields
        # the same fallback as a missing dict instead of raising KeyError.
        return self.ScoreAttributes.get("ScoreConfidence", "NOT_AVAILABLE")

    def to_doc(
        self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
    ) -> Document:
        """Convert this item to a Document.

        Args:
            page_content_formatter: Callable that builds the page content
                from this item.

        Returns:
            A Document whose metadata is the additional metadata overlaid
            with the standard keys set below.
        """
        page_content = page_content_formatter(self)
        metadata = self.get_additional_metadata()
        # The standard keys are applied last, so they win over any
        # same-named key coming from get_additional_metadata().
        metadata.update(
            {
                "result_id": self.Id,
                "document_id": self.DocumentId,
                "source": self.DocumentURI,
                "title": self.get_title(),
                "excerpt": self.get_excerpt(),
                "document_attributes": self.get_document_attributes_dict(),
                "score": self.get_score_attribute(),
            }
        )
        return Document(page_content=page_content, metadata=metadata)
class QueryResultItem(ResultItem):
    """Query API result item."""

    DocumentTitle: TextWithHighLights
    """The document title."""
    FeedbackToken: Optional[str]
    """Identifies a particular result from a particular query."""
    Format: Optional[str]
    """
    If the Type is ANSWER, then format is either:
        * TABLE: a table excerpt is returned in TableExcerpt;
        * TEXT: a text excerpt is returned in DocumentExcerpt.
    """
    Type: Optional[str]
    """Type of result: DOCUMENT or QUESTION_ANSWER or ANSWER"""
    AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
    """One or more additional attributes associated with the result."""
    DocumentExcerpt: Optional[TextWithHighLights]
    """Excerpt of the document text."""

    def get_title(self) -> str:
        """Return the text of the document title."""
        return self.DocumentTitle.Text

    def get_attribute_value(self) -> str:
        """Return the text of the first additional attribute, or ""."""
        attributes = self.AdditionalAttributes
        if attributes and attributes[0]:
            return attributes[0].get_value_text()
        return ""

    def get_excerpt(self) -> str:
        """Return the answer text if present, else the document excerpt, else ""."""
        attributes = self.AdditionalAttributes
        # An "AnswerText" first attribute takes precedence over the excerpt.
        if attributes and attributes[0].Key == "AnswerText":
            return self.get_attribute_value()
        if self.DocumentExcerpt:
            return self.DocumentExcerpt.Text
        return ""

    def get_additional_metadata(self) -> dict:
        """Return the result type as extra metadata."""
        return {"type": self.Type}
class RetrieveResultItem(ResultItem):
    """Retrieve API result item."""

    DocumentTitle: Optional[str]
    """The document title."""
    Content: Optional[str]
    """The content of the item."""

    def get_title(self) -> str:
        """Return the document title, or "" when it is unset or empty."""
        title = self.DocumentTitle
        return title if title else ""

    def get_excerpt(self) -> str:
        """Return the item content, or "" when it is unset or empty."""
        content = self.Content
        return content if content else ""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class QueryResult(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """`Amazon Kendra Query API` search result.

    It is composed of:
        * Relevant suggested answers: either a text excerpt or table excerpt.
        * Matching FAQs or questions-answer from your FAQ file.
        * Documents including an excerpt of each document with its title.
    """

    ResultItems: List[QueryResultItem]
    """The result items."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class RetrieveResult(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """`Amazon Kendra Retrieve API` search result.

    It is composed of:
        * relevant passages or text excerpts given an input query.
    """

    QueryId: str
    """The ID of the query."""
    ResultItems: List[RetrieveResultItem]
    """The result items."""
# Numeric value for each score-confidence label; the retriever compares these
# values against `min_score_confidence` when filtering documents.
KENDRA_CONFIDENCE_MAPPING = {
    "NOT_AVAILABLE": 0.0,
    "LOW": 0.25,
    "MEDIUM": 0.50,
    "HIGH": 0.75,
    "VERY_HIGH": 1.0,
}
class AmazonKendraRetriever(BaseRetriever):
    """`Amazon Kendra Index` retriever.

    Args:
        index_id: Kendra index id

        region_name: The aws region e.g., `us-west-2`.
            Falls back to AWS_DEFAULT_REGION env variable
            or region specified in ~/.aws/config.

        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.

        top_k: No of results to return

        attribute_filter: Additional filtering of results based on metadata
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        document_relevance_override_configurations: Overrides relevance tuning
            configurations of fields/attributes set at the index level
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        page_content_formatter: generates the Document page_content
            allowing access to all result item attributes. By default, it uses
            the item's title and excerpt.

        client: boto3 client for Kendra

        user_context: Provides information about the user context
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        min_score_confidence: Optional threshold between 0.0 and 1.0; when
            set to a non-zero value, only documents whose score confidence
            maps to at least this value are returned.

    Example:
        .. code-block:: python

            retriever = AmazonKendraRetriever(
                index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
            )

    """

    index_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    top_k: int = 3
    attribute_filter: Optional[Dict] = None
    document_relevance_override_configurations: Optional[List[Dict]] = None
    page_content_formatter: Callable[[ResultItem], str] = combined_text
    client: Any
    user_context: Optional[Dict] = None
    min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]

    @validator("top_k")
    def validate_top_k(cls, value: int) -> int:
        """Reject a negative ``top_k``."""
        if value >= 0:
            return value
        raise ValueError(f"top_k ({value}) cannot be negative.")

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create a boto3 Kendra client unless one was supplied.

        Raises:
            ImportError: If boto3 cannot be imported.
            ValueError: If any other error occurs while building the client.
        """
        if values.get("client") is not None:
            return values

        try:
            import boto3

            profile_name = values.get("credentials_profile_name")
            # Without a profile name, boto3 uses its default credentials.
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name
                else boto3.Session()
            )

            region_name = values.get("region_name")
            client_params = {"region_name": region_name} if region_name else {}

            values["client"] = session.client("kendra", **client_params)
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    def _kendra_query(self, query: str) -> Sequence[ResultItem]:
        """Call the Retrieve API, falling back to the Query API on no results."""
        kendra_kwargs: Dict[str, Any] = {
            "IndexId": self.index_id,
            # truncate the query to ensure that
            # there is no validation exception from Kendra.
            "QueryText": query.strip()[:999],
            "PageSize": self.top_k,
        }

        # Optional request arguments are only sent when configured.
        optional_args = (
            ("AttributeFilter", self.attribute_filter),
            (
                "DocumentRelevanceOverrideConfigurations",
                self.document_relevance_override_configurations,
            ),
            ("UserContext", self.user_context),
        )
        for name, value in optional_args:
            if value is not None:
                kendra_kwargs[name] = value

        retrieve_response = self.client.retrieve(**kendra_kwargs)
        retrieved = RetrieveResult.parse_obj(retrieve_response)
        if retrieved.ResultItems:
            return retrieved.ResultItems

        # Retrieve API returned 0 results, fall back to Query API
        query_response = self.client.query(**kendra_kwargs)
        return QueryResult.parse_obj(query_response).ResultItems

    def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
        """Convert at most ``top_k`` leading result items to Documents."""
        formatter = self.page_content_formatter
        return [item.to_doc(formatter) for item in result_items[: self.top_k]]

    def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
        """Keep only documents that meet the minimum score confidence.

        When ``min_score_confidence`` is unset (or zero) the input list is
        returned unchanged. Otherwise a document is kept only if its
        ``score`` metadata is a string whose mapped confidence value is
        greater than or equal to the threshold; unknown labels map to 0.0.
        """
        threshold = self.min_score_confidence
        if not threshold:
            return docs

        return [
            doc
            for doc in docs
            if isinstance(doc.metadata.get("score"), str)
            and KENDRA_CONFIDENCE_MAPPING.get(doc.metadata["score"], 0.0) >= threshold
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Run search on Kendra index and get top k documents.

        Example:
        .. code-block:: python

            docs = retriever.invoke('This is my query')

        """
        items = self._kendra_query(query)
        return self._filter_by_score_confidence(self._get_top_k_docs(items))