Source code for langchain_community.embeddings.sagemaker_endpoint
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_community.llms.sagemaker_endpoint import ContentHandlerBase
[docs]class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
"""LLM类的内容处理程序。"""
[docs]class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""自定义Sagemaker推理端点。
要使用,必须提供部署的Sagemaker模型的端点名称和部署的区域。
要进行身份验证,AWS客户端使用以下方法自动加载凭据:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
如果应该使用特定的凭据配置文件,必须传递要使用的位于~/.aws/credentials文件中的配置文件的名称。
确保使用的凭据/角色具有访问Sagemaker端点所需的策略。
参见:https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html"""
""" 示例:
.. code-block:: python
from langchain_community.embeddings import SagemakerEndpointEmbeddings
endpoint_name = (
"my-endpoint-name"
)
region_name = (
"us-west-2"
)
credentials_profile_name = (
"default"
)
se = SagemakerEndpointEmbeddings(
endpoint_name=endpoint_name,
region_name=region_name,
credentials_profile_name=credentials_profile_name
)
# 与boto3客户端一起使用
client = boto3.client(
"sagemaker-runtime",
region_name=region_name
)
se = SagemakerEndpointEmbeddings(
endpoint_name=endpoint_name,
client=client
)"""
client: Any = None
endpoint_name: str = ""
"""部署的Sagemaker模型的端点名称。
在AWS区域内必须是唯一的。"""
region_name: str = ""
"""Sagemaker模型部署的AWS区域,例如`us-west-2`。"""
credentials_profile_name: Optional[str] = None
"""~/.aws/credentials 或 ~/.aws/config 文件中配置文件的名称,其中指定了访问密钥或角色信息。
如果未指定,则将使用默认凭据配置文件,或者如果在EC2实例上,则将使用来自IMDS的凭据。
参见:https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html"""
content_handler: EmbeddingsContentHandler
"""提供输入和输出转换函数以处理LLM和端点之间的格式的内容处理程序类。"""
"""```python
示例:
.. code-block:: python
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompts: prompts, **model_kwargs})
return input_str.encode('utf-8')
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
```""" # noqa: E501
model_kwargs: Optional[Dict] = None
"""传递给模型的关键字参数。"""
endpoint_kwargs: Optional[Dict] = None
"""传递给invoke_endpoint函数的可选属性。有关更多信息,请参阅`boto3`文档。
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>"""
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""如果客户端是外部提供的,则不执行任何操作。"""
if values.get("client") is not None:
return values
"""Validate that AWS credentials to and python package exists in environment."""
try:
import boto3
try:
if values["credentials_profile_name"] is not None:
session = boto3.Session(
profile_name=values["credentials_profile_name"]
)
else:
# use default credentials
session = boto3.Session()
values["client"] = session.client(
"sagemaker-runtime", region_name=values["region_name"]
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. {e}"
) from e
except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
return values
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
"""调用SageMaker推理嵌入端点。"""
# replace newlines, which can negatively affect performance.
texts = list(map(lambda x: x.replace("\n", " "), texts))
_model_kwargs = self.model_kwargs or {}
_endpoint_kwargs = self.endpoint_kwargs or {}
body = self.content_handler.transform_input(texts, _model_kwargs)
content_type = self.content_handler.content_type
accepts = self.content_handler.accepts
# send request
try:
response = self.client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=body,
ContentType=content_type,
Accept=accepts,
**_endpoint_kwargs,
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
return self.content_handler.transform_output(response["Body"])
[docs] def embed_documents(
self, texts: List[str], chunk_size: int = 64
) -> List[List[float]]:
"""使用SageMaker推理端点计算文档嵌入。
参数:
texts:要嵌入的文本列表。
chunk_size:块大小定义了将多少个输入文本作为请求分组在一起。如果为None,将使用类指定的块大小。
返回:
嵌入列表,每个文本对应一个嵌入。
"""
results = []
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
for i in range(0, len(texts), _chunk_size):
response = self._embedding_func(texts[i : i + _chunk_size])
results.extend(response)
return results
[docs] def embed_query(self, text: str) -> List[float]:
"""使用SageMaker推理端点计算查询嵌入。
参数:
text:要嵌入的文本。
返回:
文本的嵌入。
"""
return self._embedding_func([text])[0]