"""Sagemaker调用端点API。"""
import io
import json
from abc import abstractmethod
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_community.llms.utils import enforce_stop_tokens
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
[docs]class LineIterator:
"""解析字节流输入。
模型的输出将采用以下格式:
b'{"outputs": [" a"]}
'
b'{"outputs": [" challenging"]}
'
b'{"outputs": [" problem"]}
'
...
通常,事件流中的每个PayloadPart事件将包含一个带有完整json的字节数组,但不能保证所有的json对象都会完整地出现在PayloadPart事件中。
例如:
{'PayloadPart': {'Bytes': b'{"outputs": '}}
{'PayloadPart': {'Bytes': b'[" problem"]}
'}}
该类通过连接通过'write'函数写入的字节,并暴露一个方法,该方法将通过'scan_lines'函数在缓冲区中返回行(以'
'字符结尾)。
它维护上次读取位置的位置,以确保不再次暴露先前的字节。
更多详情请参见:
https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
"""
[docs] def __init__(self, stream: Any) -> None:
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
def __iter__(self) -> "LineIterator":
return self
def __next__(self) -> Any:
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line)
return line[:-1]
try:
chunk = next(self.byte_iterator)
except StopIteration:
if self.read_pos < self.buffer.getbuffer().nbytes:
continue
raise
if "PayloadPart" not in chunk:
# Unknown Event Type
continue
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
[docs]class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
"""处理程序类,用于将LLM输入转换为SageMaker端点期望的格式。
同样,该类处理将SageMaker端点的输出转换为LLM类期望的格式。"""
"""```python
class ContentHandler(ContentHandlerBase):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompt: prompt, **model_kwargs})
return input_str.encode('utf-8')
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
return response_json[0]["generated_text"]
```"""
content_type: Optional[str] = "text/plain"
"""传递给端点的输入数据的MIME类型。"""
accepts: Optional[str] = "text/plain"
"""从端点返回的响应数据的MIME类型"""
[docs] @abstractmethod
def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
"""将输入转换为模型可以接受的格式,作为请求主体。应该以 content_type 请求头中指定的格式返回字节或可寻址文件对象。
"""
[docs] @abstractmethod
def transform_output(self, output: bytes) -> OUTPUT_TYPE:
"""将模型的输出转换为LLM类所期望的字符串。
"""
[docs]class LLMContentHandler(ContentHandlerBase[str, str]):
"""LLM类的内容处理程序。"""
[docs]class SagemakerEndpoint(LLM):
"""Sagemaker推理端点模型。
要使用,必须提供部署的Sagemaker模型的端点名称和部署的区域。
要进行身份验证,AWS客户端使用以下方法自动加载凭据:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
如果应该使用特定的凭据配置文件,必须传递要使用的~/.aws/credentials文件中的配置文件名称。
确保使用的凭据/角色具有访问Sagemaker端点所需的策略。
参见:https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html"""
"""Args:
region_name: AWS区域,例如`us-west-2`。
回退到AWS_DEFAULT_REGION环境变量
或在~/.aws/config中指定的区域。
credentials_profile_name: 在~/.aws/credentials
或~/.aws/config文件中的配置文件名称,其中包含访问密钥或角色信息。
如果未指定,则将使用默认凭据配置文件,或者如果在
EC2实例上,则使用IMDS中的凭据。
client: Sagemaker Endpoint的boto3客户端
content_handler: 用于特定模型的LLMContentHandler实现
Example:
.. code-block:: python
from langchain_community.llms import SagemakerEndpoint
endpoint_name = (
"my-endpoint-name"
)
region_name = (
"us-west-2"
)
credentials_profile_name = (
"default"
)
se = SagemakerEndpoint(
endpoint_name=endpoint_name,
region_name=region_name,
credentials_profile_name=credentials_profile_name
)
# 使用boto3客户端
client = boto3.client(
"sagemaker-runtime",
region_name=region_name
)
se = SagemakerEndpoint(
endpoint_name=endpoint_name,
client=client
)"""
client: Any = None
"""Sagemaker运行时的Boto3客户端"""
endpoint_name: str = ""
"""部署的Sagemaker模型的端点名称。
在AWS区域内必须是唯一的。"""
region_name: str = ""
"""Sagemaker模型部署的AWS区域,例如`us-west-2`。"""
credentials_profile_name: Optional[str] = None
"""~/.aws/credentials 或 ~/.aws/config 文件中配置文件的名称,其中指定了访问密钥或角色信息。
如果未指定,则将使用默认凭据配置文件,或者如果在EC2实例上,则将使用来自IMDS的凭据。
参见:https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html"""
content_handler: LLMContentHandler
"""提供输入和输出转换函数以处理LLM和端点之间的格式的内容处理程序类。"""
streaming: bool = False
"""是否流式传输结果。"""
"""```python
示例:
.. 代码块::python
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompt: prompt, **model_kwargs})
return input_str.encode('utf-8')
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
return response_json[0]["generated_text"]
```"""
model_kwargs: Optional[Dict] = None
"""传递给模型的关键字参数。"""
endpoint_kwargs: Optional[Dict] = None
"""传递给invoke_endpoint函数的可选属性。有关更多信息,请参阅`boto3`文档。
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>"""
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""如果客户端是外部提供的,则不执行任何操作。"""
if values.get("client") is not None:
return values
"""Validate that AWS credentials to and python package exists in environment."""
try:
import boto3
try:
if values["credentials_profile_name"] is not None:
session = boto3.Session(
profile_name=values["credentials_profile_name"]
)
else:
# use default credentials
session = boto3.Session()
values["client"] = session.client(
"sagemaker-runtime", region_name=values["region_name"]
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""获取识别参数。"""
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint_name": self.endpoint_name},
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""llm的返回类型。"""
return "sagemaker_endpoint"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""调用Sagemaker推理端点。
参数:
prompt: 传递给模型的提示。
stop: 在生成时可选的停止词列表。
返回:
模型生成的字符串。
示例:
.. code-block:: python
response = se("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {}
body = self.content_handler.transform_input(prompt, _model_kwargs)
content_type = self.content_handler.content_type
accepts = self.content_handler.accepts
if self.streaming and run_manager:
try:
resp = self.client.invoke_endpoint_with_response_stream(
EndpointName=self.endpoint_name,
Body=body,
ContentType=self.content_handler.content_type,
**_endpoint_kwargs,
)
iterator = LineIterator(resp["Body"])
current_completion: str = ""
for line in iterator:
resp = json.loads(line)
resp_output = resp.get("outputs")[0]
if stop is not None:
# Uses same approach as below
resp_output = enforce_stop_tokens(resp_output, stop)
current_completion += resp_output
run_manager.on_llm_new_token(resp_output)
return current_completion
except Exception as e:
raise ValueError(f"Error raised by streaming inference endpoint: {e}")
else:
try:
response = self.client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=body,
ContentType=content_type,
Accept=accepts,
**_endpoint_kwargs,
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
text = self.content_handler.transform_output(response["Body"])
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to the sagemaker endpoint.
text = enforce_stop_tokens(text, stop)
return text