import re
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
root_validator,
validator,
)
from langchain_core.retrievers import BaseRetriever
from typing_extensions import Annotated
def clean_excerpt(excerpt: str) -> str:
    """Clean an excerpt returned by Kendra.

    Args:
        excerpt: The excerpt to clean.

    Returns:
        The excerpt with every run of whitespace collapsed to a single space
        and every literal ``...`` removed. A falsy input (for example an
        empty string) is returned unchanged.
    """
    if excerpt:
        single_spaced = re.sub(r"\s+", " ", excerpt)
        # Ellipses are stripped after collapsing whitespace, so the spaces
        # that surrounded them are left in place.
        return single_spaced.replace("...", "")
    return excerpt
def combined_text(item: "ResultItem") -> str:
    """Combine a result item's title and excerpt into a single string.

    Args:
        item: A ResultItem of a Kendra search.

    Returns:
        The title line followed by the cleaned excerpt section. A part whose
        value is empty is left out, so the result may be an empty string.
    """
    sections = []

    title = item.get_title()
    if title:
        sections.append(f"Document Title: {title}\n")

    excerpt = clean_excerpt(item.get_excerpt())
    if excerpt:
        sections.append(f"Document Excerpt: \n{excerpt}\n")

    return "".join(sections)
# Alias for the value a document attribute can carry: a string, an integer,
# a list of strings, or None when no value is set.
DocumentAttributeValueType = Union[str, int, List[str], None]
"""Possible types of a DocumentAttributeValue.
Dates are also represented as str.
"""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class Highlight(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Information that highlights the keywords in the excerpt."""

    BeginOffset: int
    """The zero-based location in the excerpt where the highlight starts."""

    EndOffset: int
    """The zero-based location in the excerpt where the highlight ends."""

    TopAnswer: Optional[bool]
    """Indicates whether the result is the best one."""

    Type: Optional[str]
    """The highlight type: STANDARD or THESAURUS_SYNONYM."""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class TextWithHighLights(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Text together with its highlights."""

    Text: str
    """The text."""

    Highlights: Optional[Any]
    """The highlights."""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class AdditionalResultAttributeValue(  # type: ignore[call-arg]
    BaseModel, extra=Extra.allow
):
    """Value of an additional result attribute."""

    TextWithHighlightsValue: TextWithHighLights
    """The text-with-highlights value."""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class AdditionalResultAttribute(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Additional result attribute."""

    Key: str
    """The key of the attribute."""

    ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
    """The type of the value."""

    Value: AdditionalResultAttributeValue
    """The value of the attribute."""

    def get_value_text(self) -> str:
        """Return the plain text stored in this attribute's value."""
        highlighted = self.Value.TextWithHighlightsValue
        return highlighted.Text
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class DocumentAttributeValue(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Value of a document attribute."""

    DateValue: Optional[str]
    """The date expressed as an ISO 8601 string."""

    LongValue: Optional[int]
    """The long (integer) value."""

    StringListValue: Optional[List[str]]
    """The string list value."""

    StringValue: Optional[str]
    """The string value."""

    @property
    def value(self) -> DocumentAttributeValueType:
        """The only defined document attribute value, or None.

        According to Amazon Kendra, you can only provide one value for a
        document attribute.

        Returns:
            The first value that is set, checked in the order ``DateValue``,
            ``LongValue``, ``StringListValue``, ``StringValue``; None when
            none of them is set. Empty strings and empty lists count as
            unset.
        """
        if self.DateValue:
            return self.DateValue
        # Compare against None rather than relying on truthiness: 0 is a
        # legitimate integer value and must not be reported as "no value".
        if self.LongValue is not None:
            return self.LongValue
        if self.StringListValue:
            return self.StringListValue
        if self.StringValue:
            return self.StringValue
        return None
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class DocumentAttribute(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Document attribute."""

    Key: str
    """The key of the attribute."""

    Value: DocumentAttributeValue
    """The value of the attribute."""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class ResultItem(BaseModel, ABC, extra=Extra.allow):  # type: ignore[call-arg]
    """Base class of a result item."""

    Id: Optional[str]
    """The ID of the relevant result item."""

    DocumentId: Optional[str]
    """The document ID."""

    DocumentURI: Optional[str]
    """The document URI."""

    DocumentAttributes: Optional[List[DocumentAttribute]] = []
    """The document attributes."""

    ScoreAttributes: Optional[dict]
    """The Kendra score confidence."""

    @abstractmethod
    def get_title(self) -> str:
        """Document title."""

    @abstractmethod
    def get_excerpt(self) -> str:
        """Raw content of the document excerpt or passage retrieved from
        Kendra.
        """

    def get_additional_metadata(self) -> dict:
        """Return extra metadata to merge into the document metadata.

        ``to_doc`` starts from this dict and then sets ``result_id``,
        ``document_id``, ``source``, ``title``, ``excerpt``,
        ``document_attributes`` and ``score`` on top of it, so those keys are
        overwritten if returned here. The base implementation contributes
        nothing; subclasses may override it.

        Returns:
            A new, empty dict.
        """
        return {}

    def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
        """Return the document attributes as a ``{key: value}`` dict.

        A missing (None) attribute list yields an empty dict.
        """
        return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}

    def get_score_attribute(self) -> str:
        """Return the score confidence label of the document.

        Returns:
            The ``ScoreConfidence`` entry of ``ScoreAttributes``, or
            ``"NOT_AVAILABLE"`` when ``ScoreAttributes`` is None or has no
            such entry.
        """
        if self.ScoreAttributes is None:
            return "NOT_AVAILABLE"
        # A plain index would raise KeyError on a score dict without the
        # confidence entry; fall back to the same sentinel instead.
        return self.ScoreAttributes.get("ScoreConfidence", "NOT_AVAILABLE")

    def to_doc(
        self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
    ) -> Document:
        """Convert this item to a Document.

        Args:
            page_content_formatter: Callable that builds the page content
                from this item.

        Returns:
            A Document whose page content comes from the formatter and whose
            metadata is the additional metadata plus the result id, document
            id, source URI, title, excerpt, document attributes and score.
        """
        page_content = page_content_formatter(self)
        metadata = self.get_additional_metadata()
        metadata.update(
            {
                "result_id": self.Id,
                "document_id": self.DocumentId,
                "source": self.DocumentURI,
                "title": self.get_title(),
                "excerpt": self.get_excerpt(),
                "document_attributes": self.get_document_attributes_dict(),
                "score": self.get_score_attribute(),
            }
        )
        return Document(page_content=page_content, metadata=metadata)
class QueryResultItem(ResultItem):
    """Query API result item."""

    DocumentTitle: TextWithHighLights
    """The document title."""

    FeedbackToken: Optional[str]
    """Identifies a particular result from a particular query."""

    Format: Optional[str]
    """If the Type is ANSWER, then the format is either:
    * TABLE: a table excerpt is returned in TableExcerpt;
    * TEXT: a text excerpt is returned in DocumentExcerpt."""

    Type: Optional[str]
    """Type of result: DOCUMENT or QUESTION_ANSWER or ANSWER."""

    AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
    """One or more additional attributes associated with the result."""

    DocumentExcerpt: Optional[TextWithHighLights]
    """Excerpt of the document text."""

    def get_title(self) -> str:
        """Return the text of the document title."""
        return self.DocumentTitle.Text

    def get_attribute_value(self) -> str:
        """Return the text of the first additional attribute.

        Returns an empty string when there are no additional attributes or
        the first one is falsy.
        """
        attributes = self.AdditionalAttributes
        if attributes and attributes[0]:
            return attributes[0].get_value_text()
        return ""

    def get_excerpt(self) -> str:
        """Return the excerpt text of this item.

        The first additional attribute wins when its key is ``AnswerText``;
        otherwise the document excerpt is used, and an empty string is
        returned when neither is available.
        """
        attributes = self.AdditionalAttributes
        if attributes and attributes[0].Key == "AnswerText":
            return self.get_attribute_value()
        if self.DocumentExcerpt:
            return self.DocumentExcerpt.Text
        return ""
class RetrieveResultItem(ResultItem):
    """Retrieve API result item."""

    DocumentTitle: Optional[str]
    """The document title."""

    Content: Optional[str]
    """The content of the item."""

    def get_title(self) -> str:
        """Return the document title, or an empty string when it is unset."""
        title = self.DocumentTitle
        return title if title else ""

    def get_excerpt(self) -> str:
        """Return the item content, or an empty string when it is unset."""
        content = self.Content
        return content if content else ""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class QueryResult(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """`Amazon Kendra Query API` search result.

    It is composed of:
        * Relevant suggested answers: either a text excerpt or table excerpt.
        * Matching FAQs or question-answer pairs from an FAQ file.
        * Documents, including an excerpt of each document and its title.
    """

    ResultItems: List[QueryResultItem]
    """The result items."""
# The ignore below silences mypy's
# 'Unexpected keyword argument "extra" for "__init_subclass__" of "object"'.
class RetrieveResult(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """`Amazon Kendra Retrieve API` search result.

    It is composed of:
        * Relevant passages or text excerpts given an input query.
    """

    QueryId: str
    """The ID of the query."""

    ResultItems: List[RetrieveResultItem]
    """The result items."""
# Numeric value assigned to each score-confidence label. The retriever looks
# a document's "score" metadata up here and compares the result against
# ``min_score_confidence``.
KENDRA_CONFIDENCE_MAPPING = {
"NOT_AVAILABLE": 0.0,
"LOW": 0.25,
"MEDIUM": 0.50,
"HIGH": 0.75,
"VERY_HIGH": 1.0,
}
class AmazonKendraRetriever(BaseRetriever):
    """`Amazon Kendra Index` retriever.

    Args:
        index_id: Kendra index id.
        region_name: The AWS region, e.g. `us-west-2`.
            Falls back to the AWS_DEFAULT_REGION env variable
            or the region specified in ~/.aws/config.
        credentials_profile_name: The name of the profile in the
            ~/.aws/credentials or ~/.aws/config files, which has either
            access keys or role information specified. If not specified,
            the default credential profile or, if on an EC2 instance,
            credentials from IMDS will be used.
        top_k: Number of results to return.
        attribute_filter: Additional filtering of results based on metadata.
            See: https://docs.aws.amazon.com/kendra/latest/APIReference
        document_relevance_override_configurations: Overrides relevance
            tuning configurations of fields/attributes set at the index
            level.
            See: https://docs.aws.amazon.com/kendra/latest/APIReference
        page_content_formatter: Generates the Document page_content and
            allows access to all result item attributes. By default, it
            uses the item's title and excerpt.
        client: boto3 client for Kendra.
        user_context: Provides information about the user context.
            See: https://docs.aws.amazon.com/kendra/latest/APIReference
        min_score_confidence: Optional threshold between 0.0 and 1.0. When
            set to a non-zero value, only documents whose score confidence
            label maps to at least this value are returned.

    Example:
        .. code-block:: python

            retriever = AmazonKendraRetriever(
                index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
            )
    """

    index_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    top_k: int = 3
    attribute_filter: Optional[Dict] = None
    document_relevance_override_configurations: Optional[List[Dict]] = None
    page_content_formatter: Callable[[ResultItem], str] = combined_text
    client: Any
    user_context: Optional[Dict] = None
    min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]

    @validator("top_k")
    def validate_top_k(cls, value: int) -> int:
        """Reject a negative ``top_k``; any other value passes through."""
        if value >= 0:
            return value
        raise ValueError(f"top_k ({value}) cannot be negative.")

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create a boto3 Kendra client unless one was supplied.

        Raises:
            ImportError: If boto3 is not installed.
            ValueError: If anything else fails while building the session or
                the client.
        """
        if values.get("client") is not None:
            return values

        try:
            import boto3

            profile_name = values.get("credentials_profile_name")
            # Without a profile name boto3 uses its default credentials.
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name
                else boto3.Session()
            )

            region = values.get("region_name")
            client_params = {"region_name": region} if region else {}

            values["client"] = session.client("kendra", **client_params)
            return values
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _kendra_query(self, query: str) -> Sequence[ResultItem]:
        """Query Kendra, preferring the Retrieve API over the Query API.

        The Query API is called with the same arguments only when the
        Retrieve API returns no result items.
        """
        kendra_kwargs = {
            "IndexId": self.index_id,
            # truncate the query to ensure that
            # there is no validation exception from Kendra.
            "QueryText": query.strip()[0:999],
            "PageSize": self.top_k,
        }

        # Optional settings are forwarded only when they are configured.
        optional_settings = (
            ("AttributeFilter", self.attribute_filter),
            (
                "DocumentRelevanceOverrideConfigurations",
                self.document_relevance_override_configurations,
            ),
            ("UserContext", self.user_context),
        )
        for name, setting in optional_settings:
            if setting is not None:
                kendra_kwargs[name] = setting

        retrieved = RetrieveResult.parse_obj(self.client.retrieve(**kendra_kwargs))
        if retrieved.ResultItems:
            return retrieved.ResultItems

        # Retrieve API returned 0 results, fall back to Query API
        queried = QueryResult.parse_obj(self.client.query(**kendra_kwargs))
        return queried.ResultItems

    def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
        """Convert at most ``top_k`` leading result items into Documents."""
        formatter = self.page_content_formatter
        return [item.to_doc(formatter) for item in result_items[: self.top_k]]

    def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
        """Keep only documents that meet the minimum score confidence.

        When ``min_score_confidence`` is unset (or zero) the input list is
        returned as is. Otherwise a document is kept only if its ``score``
        metadata is a string whose mapped confidence is at least the
        threshold; unknown labels count as 0.0.
        """
        threshold = self.min_score_confidence
        if not threshold:
            return docs

        kept: List[Document] = []
        for doc in docs:
            label = doc.metadata.get("score")
            if not isinstance(label, str):
                continue
            if KENDRA_CONFIDENCE_MAPPING.get(label, 0.0) >= threshold:
                kept.append(doc)
        return kept

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Run a search on the Kendra index and return the top k documents.

        Example:
            .. code-block:: python

                docs = retriever.invoke('This is my query')
        """
        candidates = self._get_top_k_docs(self._kendra_query(query))
        return self._filter_by_score_confidence(candidates)