Source code for langchain_community.embeddings.sagemaker_endpoint

from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.llms.sagemaker_endpoint import ContentHandlerBase


[docs]class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]): """LLM类的内容处理程序。"""
[docs]class SagemakerEndpointEmbeddings(BaseModel, Embeddings): """自定义Sagemaker推理端点。 要使用,必须提供部署的Sagemaker模型的端点名称和部署的区域。 要进行身份验证,AWS客户端使用以下方法自动加载凭据: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html 如果应该使用特定的凭据配置文件,必须传递要使用的位于~/.aws/credentials文件中的配置文件的名称。 确保使用的凭据/角色具有访问Sagemaker端点所需的策略。 参见:https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html""" """ 示例: .. code-block:: python from langchain_community.embeddings import SagemakerEndpointEmbeddings endpoint_name = ( "my-endpoint-name" ) region_name = ( "us-west-2" ) credentials_profile_name = ( "default" ) se = SagemakerEndpointEmbeddings( endpoint_name=endpoint_name, region_name=region_name, credentials_profile_name=credentials_profile_name ) # 与boto3客户端一起使用 client = boto3.client( "sagemaker-runtime", region_name=region_name ) se = SagemakerEndpointEmbeddings( endpoint_name=endpoint_name, client=client )""" client: Any = None endpoint_name: str = "" """部署的Sagemaker模型的端点名称。 在AWS区域内必须是唯一的。""" region_name: str = "" """Sagemaker模型部署的AWS区域,例如`us-west-2`。""" credentials_profile_name: Optional[str] = None """~/.aws/credentials 或 ~/.aws/config 文件中配置文件的名称,其中指定了访问密钥或角色信息。 如果未指定,则将使用默认凭据配置文件,或者如果在EC2实例上,则将使用来自IMDS的凭据。 参见:https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html""" content_handler: EmbeddingsContentHandler """提供输入和输出转换函数以处理LLM和端点之间的格式的内容处理程序类。""" """```python 示例: .. code-block:: python from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler class ContentHandler(EmbeddingsContentHandler): content_type = "application/json" accepts = "application/json" def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes: input_str = json.dumps({prompts: prompts, **model_kwargs}) return input_str.encode('utf-8') def transform_output(self, output: bytes) -> List[List[float]]: response_json = json.loads(output.read().decode("utf-8")) return response_json["vectors"] ```""" # noqa: E501 model_kwargs: Optional[Dict] = None """传递给模型的关键字参数。""" endpoint_kwargs: Optional[Dict] = None """传递给invoke_endpoint函数的可选属性。有关更多信息,请参阅`boto3`文档。 .. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>""" class Config: """此pydantic对象的配置。""" extra = Extra.forbid arbitrary_types_allowed = True @root_validator() def validate_environment(cls, values: Dict) -> Dict: """如果客户端是外部提供的,则不执行任何操作。""" if values.get("client") is not None: return values """Validate that AWS credentials to and python package exists in environment.""" try: import boto3 try: if values["credentials_profile_name"] is not None: session = boto3.Session( profile_name=values["credentials_profile_name"] ) else: # use default credentials session = boto3.Session() values["client"] = session.client( "sagemaker-runtime", region_name=values["region_name"] ) except Exception as e: raise ValueError( "Could not load credentials to authenticate with AWS client. " "Please check that credentials in the specified " f"profile name are valid. {e}" ) from e except ImportError: raise ImportError( "Could not import boto3 python package. " "Please install it with `pip install boto3`." ) return values def _embedding_func(self, texts: List[str]) -> List[List[float]]: """调用SageMaker推理嵌入端点。""" # replace newlines, which can negatively affect performance. texts = list(map(lambda x: x.replace("\n", " "), texts)) _model_kwargs = self.model_kwargs or {} _endpoint_kwargs = self.endpoint_kwargs or {} body = self.content_handler.transform_input(texts, _model_kwargs) content_type = self.content_handler.content_type accepts = self.content_handler.accepts # send request try: response = self.client.invoke_endpoint( EndpointName=self.endpoint_name, Body=body, ContentType=content_type, Accept=accepts, **_endpoint_kwargs, ) except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e}") return self.content_handler.transform_output(response["Body"])
[docs] def embed_documents( self, texts: List[str], chunk_size: int = 64 ) -> List[List[float]]: """使用SageMaker推理端点计算文档嵌入。 参数: texts:要嵌入的文本列表。 chunk_size:块大小定义了将多少个输入文本作为请求分组在一起。如果为None,将使用类指定的块大小。 返回: 嵌入列表,每个文本对应一个嵌入。 """ results = [] _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size for i in range(0, len(texts), _chunk_size): response = self._embedding_func(texts[i : i + _chunk_size]) results.extend(response) return results
[docs] def embed_query(self, text: str) -> List[float]: """使用SageMaker推理端点计算查询嵌入。 参数: text:要嵌入的文本。 返回: 文本的嵌入。 """ return self._embedding_func([text])[0]