# Source code for langchain_community.llms.sagemaker_endpoint

"""Sagemaker调用端点API。"""
import io
import json
from abc import abstractmethod
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator

from langchain_community.llms.utils import enforce_stop_tokens

INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])


class LineIterator:
    """Parse a SageMaker response-stream byte iterator into newline-delimited lines.

    The output of the model will be in the following format:

    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...

    While usually each PayloadPart event from the event stream will contain
    a byte array with a full json, this is not guaranteed and some of the
    json objects may be split across PayloadPart events. For example:

    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}

    This class accounts for this by concatenating bytes written via the
    'write' function and then exposing a method which will return lines
    (ending with a '\n' character) within the buffer via the 'scan_lines'
    function. It maintains the position of the last read position to ensure
    that previous bytes are not exposed again.

    For more details see:
    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
    """

    def __init__(self, stream: Any) -> None:
        """Wrap *stream* (an iterable of event-stream dicts) for line iteration."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first byte in `buffer` not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self) -> "LineIterator":
        return self

    def __next__(self) -> Any:
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                # Strip the trailing newline before handing the line out.
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # BUGFIX: the original `continue`d here, which looped
                    # forever whenever the stream ended without a trailing
                    # newline (readline keeps returning the same tail and the
                    # byte iterator keeps raising StopIteration). Return the
                    # remaining partial line once; the next call will see an
                    # empty buffer and raise StopIteration cleanly.
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # Unknown Event Type
                continue
            # Append the new payload bytes to the end of the buffer without
            # disturbing `read_pos`.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
    """Handler class to transform the input of an LLM into the format that the
    SageMaker endpoint expects, and to transform the endpoint output back into
    the format the LLM class expects.

    Example:
        .. code-block:: python

            class ContentHandler(ContentHandlerBase):
                content_type = "application/json"
                accepts = "application/json"

                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
                    input_str = json.dumps({prompt: prompt, **model_kwargs})
                    return input_str.encode('utf-8')

                def transform_output(self, output: bytes) -> str:
                    response_json = json.loads(output.read().decode("utf-8"))
                    return response_json[0]["generated_text"]
    """

    content_type: Optional[str] = "text/plain"
    """MIME type of the input data sent to the endpoint."""

    accepts: Optional[str] = "text/plain"
    """MIME type of the response data returned by the endpoint."""

    @abstractmethod
    def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
        """Serialize *prompt* and *model_kwargs* into the request body.

        Must return bytes (or a seekable file-like object) in the format
        declared by the ``content_type`` request header.
        """

    @abstractmethod
    def transform_output(self, output: bytes) -> OUTPUT_TYPE:
        """Deserialize the raw endpoint response into the value the LLM
        class expects.
        """
class LLMContentHandler(ContentHandlerBase[str, str]):
    """Content handler for the LLM class: accepts a single prompt string and
    produces a single completion string."""
class SagemakerEndpoint(LLM):
    """Sagemaker Inference Endpoint models.

    To use, you must supply the endpoint name from your deployed
    Sagemaker model and the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html

    Args:
        region_name: The aws region e.g., `us-west-2`.
            Falls back to the AWS_DEFAULT_REGION env variable
            or the region specified in ~/.aws/config.
        credentials_profile_name: The name of the profile in the
            ~/.aws/credentials or ~/.aws/config files, which has either access
            keys or role information specified. If not specified, the default
            credential profile or, if on an EC2 instance, credentials from
            IMDS will be used.
        client: boto3 client for the Sagemaker Endpoint.
        content_handler: Model-specific LLMContentHandler implementation.

    Example:
        .. code-block:: python

            from langchain_community.llms import SagemakerEndpoint
            endpoint_name = (
                "my-endpoint-name"
            )
            region_name = (
                "us-west-2"
            )
            credentials_profile_name = (
                "default"
            )
            se = SagemakerEndpoint(
                endpoint_name=endpoint_name,
                region_name=region_name,
                credentials_profile_name=credentials_profile_name
            )

            # Usage with a boto3 client
            client = boto3.client(
                "sagemaker-runtime",
                region_name=region_name
            )
            se = SagemakerEndpoint(
                endpoint_name=endpoint_name,
                client=client
            )
    """

    client: Any = None
    """Boto3 client for sagemaker runtime."""

    endpoint_name: str = ""
    """The name of the endpoint from the deployed Sagemaker model.
    Must be unique within an AWS Region."""

    region_name: str = ""
    """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files,
    which has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    content_handler: LLMContentHandler
    """The content handler class that provides the input and output transform
    functions to handle formats between the LLM and the endpoint."""

    streaming: bool = False
    """Whether to stream the results."""

    """
    Example:
        .. code-block:: python

            from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

            class ContentHandler(LLMContentHandler):
                content_type = "application/json"
                accepts = "application/json"

                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
                    input_str = json.dumps({prompt: prompt, **model_kwargs})
                    return input_str.encode('utf-8')

                def transform_output(self, output: bytes) -> str:
                    response_json = json.loads(output.read().decode("utf-8"))
                    return response_json[0]["generated_text"]
    """

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes passed to the invoke_endpoint function.
    See the `boto3`_ docs for more info.
    .. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Don't do anything if the client is provided externally."""
        if values.get("client") is not None:
            return values

        # Validate that AWS credentials and the boto3 python package exist in
        # the environment, then build the sagemaker-runtime client.
        try:
            import boto3

            try:
                if values["credentials_profile_name"] is not None:
                    session = boto3.Session(
                        profile_name=values["credentials_profile_name"]
                    )
                else:
                    # use default credentials
                    session = boto3.Session()

                values["client"] = session.client(
                    "sagemaker-runtime", region_name=values["region_name"]
                )

            except Exception as e:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_name": self.endpoint_name},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "sagemaker_endpoint"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to the Sagemaker inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = se("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}
        # Per-call kwargs override the instance-level model_kwargs.
        _model_kwargs = {**_model_kwargs, **kwargs}
        _endpoint_kwargs = self.endpoint_kwargs or {}

        body = self.content_handler.transform_input(prompt, _model_kwargs)
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        # Streaming is only attempted when a run_manager is available to
        # receive the per-token callbacks.
        if self.streaming and run_manager:
            try:
                resp = self.client.invoke_endpoint_with_response_stream(
                    EndpointName=self.endpoint_name,
                    Body=body,
                    ContentType=self.content_handler.content_type,
                    **_endpoint_kwargs,
                )
                iterator = LineIterator(resp["Body"])
                current_completion: str = ""
                for line in iterator:
                    resp = json.loads(line)
                    # NOTE(review): the streaming path assumes each line is
                    # JSON of the form {"outputs": ["<token>"]} and does not
                    # consult content_handler.transform_output — confirm this
                    # matches the deployed model's stream format.
                    resp_output = resp.get("outputs")[0]
                    if stop is not None:
                        # Uses same approach as below
                        resp_output = enforce_stop_tokens(resp_output, stop)
                    current_completion += resp_output
                    run_manager.on_llm_new_token(resp_output)
                return current_completion
            except Exception as e:
                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
        else:
            try:
                response = self.client.invoke_endpoint(
                    EndpointName=self.endpoint_name,
                    Body=body,
                    ContentType=content_type,
                    Accept=accepts,
                    **_endpoint_kwargs,
                )
            except Exception as e:
                raise ValueError(f"Error raised by inference endpoint: {e}")

            text = self.content_handler.transform_output(response["Body"])
            if stop is not None:
                # This is a bit hacky, but I can't figure out a better way to enforce
                # stop tokens when making calls to the sagemaker endpoint.
                text = enforce_stop_tokens(text, stop)

            return text