Source code for langchain_community.llms.amazon_api_gateway
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra
from langchain_community.llms.utils import enforce_stop_tokens
[docs]class ContentHandlerAmazonAPIGateway:
"""适配器,用于将来自Langchain的输入准备成LLM模型期望的格式。
它还提供了一个辅助函数,用于从模型响应中提取生成的文本。"""
[docs] @classmethod
def transform_input(
cls, prompt: str, model_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
return {"inputs": prompt, "parameters": model_kwargs}
[docs] @classmethod
def transform_output(cls, response: Any) -> str:
return response.json()[0]["generated_text"]
[docs]class AmazonAPIGateway(LLM):
"""使用Amazon API Gateway访问托管在AWS上的LLM模型。"""
api_url: str
"""API网关URL"""
headers: Optional[Dict] = None
"""API网关发送的HTTP标头,例如用于身份验证。"""
model_kwargs: Optional[Dict] = None
"""传递给模型的关键字参数。"""
content_handler: ContentHandlerAmazonAPIGateway = ContentHandlerAmazonAPIGateway()
"""提供输入和输出转换函数以处理LLM和端点之间的格式的内容处理程序类。"""
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""获取识别参数。"""
_model_kwargs = self.model_kwargs or {}
return {
**{"api_url": self.api_url, "headers": self.headers},
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""llm的返回类型。"""
return "amazon_api_gateway"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""调用Amazon API Gateway模型。
参数:
prompt: 传递给模型的提示。
stop: 生成时可选的停用词列表。
返回:
模型生成的字符串。
示例:
.. code-block:: python
response = se("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
payload = self.content_handler.transform_input(prompt, _model_kwargs)
try:
response = requests.post(
self.api_url,
headers=self.headers,
json=payload,
)
text = self.content_handler.transform_output(response)
except Exception as error:
raise ValueError(f"Error raised by the service: {error}")
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text