import json
import urllib.request
import warnings
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
# Default request timeout for endpoint calls; forwarded as the `timeout`
# argument of urllib.request.urlopen (seconds).
DEFAULT_TIMEOUT = 50
class AzureMLEndpointClient(object):
    """Client for an AzureML managed endpoint.

    Posts a raw request body to the endpoint's REST URL using bearer-token
    authentication and returns the raw response bytes.
    """

    def __init__(
        self,
        endpoint_url: str,
        endpoint_api_key: str,
        deployment_name: str = "",
        timeout: int = DEFAULT_TIMEOUT,
    ) -> None:
        """Initialize the client.

        Args:
            endpoint_url: Full invocation URL of the endpoint.
            endpoint_api_key: Key/token sent as ``Authorization: Bearer <key>``.
            deployment_name: Optional deployment to pin requests to. When empty,
                no ``azureml-model-deployment`` header is sent.
            timeout: Default timeout forwarded to ``urllib.request.urlopen``.

        Raises:
            ValueError: If ``endpoint_url`` or ``endpoint_api_key`` is empty.
        """
        if not endpoint_api_key or not endpoint_url:
            raise ValueError(
                "A key/token and REST endpoint should\n"
                "be provided to invoke the endpoint"
            )
        self.endpoint_url = endpoint_url
        self.endpoint_api_key = endpoint_api_key
        self.deployment_name = deployment_name
        self.timeout = timeout

    def call(
        self,
        body: bytes,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> bytes:
        """Send ``body`` to the endpoint and return the raw response bytes.

        Args:
            body: Serialized request payload.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Only ``timeout`` is read; it overrides the client-level
                timeout for this single call.

        Returns:
            The full response body as bytes.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.endpoint_api_key),
        }
        # The azureml-model-deployment header will force the request to go to a
        # specific deployment. Omitting it lets the request observe the
        # endpoint traffic rules.
        if self.deployment_name != "":
            headers["azureml-model-deployment"] = self.deployment_name

        req = urllib.request.Request(self.endpoint_url, body, headers)
        # Use the response as a context manager so the underlying connection is
        # closed even if reading the body raises.
        with urllib.request.urlopen(
            req, timeout=kwargs.get("timeout", self.timeout)
        ) as response:
            return response.read()
class AzureMLEndpointApiType(str, Enum):
    """API flavour exposed by an Azure ML endpoint.

    Pick ``dedicated`` for models hosted on managed infrastructure (online
    endpoints in Azure Machine Learning), or ``serverless`` for models deployed
    as a pay-as-you-go or PTU service.
    """

    # Managed-infrastructure (online endpoint) deployments.
    dedicated = "dedicated"

    realtime = "realtime"  #: Deprecated

    # Pay-as-you-go / PTU service deployments.
    serverless = "serverless"
class ContentFormatterBase:
    """Transform requests to and responses from an AzureML endpoint.

    Subclasses adapt a plain prompt to the input schema of a specific deployed
    model, and the model's output back into a ``Generation``.

    Example:
        .. code-block:: python

            class ContentFormatter(ContentFormatterBase):
                content_type = "application/json"
                accepts = "application/json"

                def format_request_payload(
                    self,
                    prompt: str,
                    model_kwargs: Dict,
                    api_type: AzureMLEndpointApiType,
                ) -> bytes:
                    input_str = json.dumps(
                        {
                            "inputs": {"input_string": [prompt]},
                            "parameters": model_kwargs,
                        }
                    )
                    return str.encode(input_str)

                def format_response_payload(
                    self, output: bytes, api_type: AzureMLEndpointApiType
                ) -> Generation:
                    response_json = json.loads(output)
                    return Generation(text=response_json[0]["0"])
    """

    content_type: Optional[str] = "application/json"
    """The MIME type of the input data passed to the endpoint."""

    accepts: Optional[str] = "application/json"
    """The MIME type of the response data returned from the endpoint."""

    format_error_msg: str = (
        "Error while formatting response payload for chat model of type "
        "`{api_type}`. Are you using the right formatter for the deployed "
        "model and endpoint type?"
    )

    @staticmethod
    def escape_special_characters(prompt: str) -> str:
        """Return ``prompt`` with backslashes, quotes and control chars escaped."""
        # Backslash must be escaped first so backslashes introduced by later
        # replacements are not escaped a second time (dicts keep insertion order).
        escape_map = {
            "\\": "\\\\",
            '"': '\\"',
            "\b": "\\b",
            "\f": "\\f",
            "\n": "\\n",
            "\r": "\\r",
            "\t": "\\t",
        }
        for escape_sequence, escaped_sequence in escape_map.items():
            prompt = prompt.replace(escape_sequence, escaped_sequence)
        return prompt

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """API types this formatter supports.

        Azure ML can deploy models with different hosting methods, and each may
        expose a different API structure.
        """
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Any:
        """Format the request body according to the model's input schema.

        Implementations return bytes or a seekable file-like object matching
        ``content_type``.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()

    @abstractmethod
    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Generation:
        """Format the response body according to the model's output schema.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # This class is not an ``abc.ABC``, so ``@abstractmethod`` is not
        # enforced; raise explicitly instead of silently returning None.
        raise NotImplementedError()
class GPT2ContentFormatter(ContentFormatterBase):
    """Content formatter for GPT-2 models."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Only ``dedicated`` endpoints are supported."""
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Build ``{"inputs": {"input_string": [...]}, "parameters": ...}``.

        The prompt is escaped and wrapped in literal double quotes first.
        """
        escaped = ContentFormatterBase.escape_special_characters(prompt)
        payload = {
            "inputs": {"input_string": [f'"{escaped}"']},
            "parameters": model_kwargs,
        }
        return json.dumps(payload).encode()

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Use ``response[0]["0"]`` as the generated text.

        Raises:
            ValueError: If the response does not have that shape.
        """
        try:
            text = json.loads(output)[0]["0"]
        except (KeyError, IndexError, TypeError) as err:
            message = self.format_error_msg.format(api_type=api_type)  # type: ignore[union-attr]
            raise ValueError(message) from err
        return Generation(text=text)
class OSSContentFormatter(GPT2ContentFormatter):
    """Deprecated: kept for backwards compatibility.

    Content handler for LLMs from the OSS catalog. Use ``GPT2ContentFormatter``
    instead.
    """

    content_formatter: Any = None

    def __init__(self) -> None:
        super().__init__()
        # stacklevel=2 attributes the warning to the code that constructs the
        # formatter rather than to this __init__.
        warnings.warn(
            "`OSSContentFormatter` will be deprecated in the future.\n"
            "Please use `GPT2ContentFormatter` instead.\n",
            stacklevel=2,
        )
class HFContentFormatter(ContentFormatterBase):
    """Content handler for LLMs from the HuggingFace catalog."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Only ``dedicated`` endpoints are supported."""
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Build ``{"inputs": ["<escaped prompt>"], "parameters": ...}``."""
        # The escaped result must be kept (it was previously discarded), in
        # line with every other formatter in this module.
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Use ``response[0]["0"]["generated_text"]`` as the generated text.

        Raises:
            ValueError: If the response does not have that shape.
        """
        try:
            choice = json.loads(output)[0]["0"]["generated_text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)
class DollyContentFormatter(ContentFormatterBase):
    """Content handler for the Dolly-v2-12b model."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Only ``dedicated`` endpoints are supported."""
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Wrap the escaped, double-quoted prompt in an ``input_data`` payload."""
        quoted = f'"{ContentFormatterBase.escape_special_characters(prompt)}"'
        body = json.dumps(
            {"input_data": {"input_string": [quoted]}, "parameters": model_kwargs}
        )
        return body.encode()

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Use the first element of the JSON response as the generated text.

        Raises:
            ValueError: If the response cannot be indexed that way.
        """
        try:
            first = json.loads(output)[0]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(
                self.format_error_msg.format(api_type=api_type)  # type: ignore[union-attr]
            ) from err
        return Generation(text=first)
class CustomOpenAIContentFormatter(ContentFormatterBase):
    """Content formatter for models that use the OpenAI-like API scheme."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Both ``dedicated`` and ``serverless`` endpoints are supported."""
        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Format the request according to the chosen API type.

        Raises:
            ValueError: If ``api_type`` is not handled by this formatter.
        """
        # NOTE(review): the prompt is escaped here and escaped again by
        # json.dumps; kept as-is since endpoints may depend on this format.
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": [f'"{prompt}"'],
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Format the response according to the chosen API type.

        Raises:
            ValueError: If the response is malformed for ``api_type``, or the
                API type is not handled by this formatter.
        """
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            try:
                choice = json.loads(output)[0]["0"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(text=choice)
        if api_type == AzureMLEndpointApiType.serverless:
            try:
                choice = json.loads(output)["choices"][0]
                if not isinstance(choice, dict):
                    raise TypeError(
                        "Endpoint response is not well formed for a chat "
                        f"model. Expected `dict` but `{type(choice)}` was "
                        "received."
                    )
                # Read the text inside the try so a missing key surfaces as the
                # formatter's ValueError rather than a bare KeyError.
                text = choice["text"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(
                text=text.strip(),
                generation_info=dict(
                    finish_reason=choice.get("finish_reason"),
                    logprobs=choice.get("logprobs"),
                ),
            )
        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class LlamaContentFormatter(CustomOpenAIContentFormatter):
    """Deprecated: kept for backwards compatibility.

    Content formatter for Llama. Use ``CustomOpenAIContentFormatter`` instead.
    """

    content_formatter: Any = None

    def __init__(self) -> None:
        super().__init__()
        # stacklevel=2 attributes the warning to the code that constructs the
        # formatter rather than to this __init__.
        warnings.warn(
            "`LlamaContentFormatter` will be deprecated in the future.\n"
            "Please use `CustomOpenAIContentFormatter` instead.\n",
            stacklevel=2,
        )
class AzureMLBaseEndpoint(BaseModel):
    """Azure ML online endpoint configuration shared by endpoint models."""

    endpoint_url: str = ""
    """URL of a pre-existing endpoint. Should be passed to the constructor or
    specified as the environment variable `AZUREML_ENDPOINT_URL`."""

    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
    """Type of the endpoint being consumed. Possible values are `serverless` for
    pay-as-you-go and `dedicated` for dedicated endpoints."""

    endpoint_api_key: SecretStr = convert_to_secret_str("")
    """Authentication key for the endpoint. Should be passed to the constructor
    or specified as the environment variable `AZUREML_ENDPOINT_API_KEY`."""

    deployment_name: str = ""
    """Deployment name of the endpoint. Not required to invoke the endpoint.
    Should be passed to the constructor or specified as the environment
    variable `AZUREML_DEPLOYMENT_NAME`."""

    timeout: int = DEFAULT_TIMEOUT
    """Request timeout for calls to the endpoint."""

    http_client: Any = None  #: :meta private:

    max_retries: int = 1

    content_formatter: Any = None
    """The content formatter that provides input and output transform functions
    to handle formats between the LLM and the endpoint."""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    @root_validator(pre=True)
    def validate_environ(cls, values: Dict) -> Dict:
        """Fill connection settings from constructor values or environment."""
        values["endpoint_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        values["deployment_name"] = get_from_dict_or_env(
            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
        )
        values["endpoint_api_type"] = get_from_dict_or_env(
            values,
            "endpoint_api_type",
            "AZUREML_ENDPOINT_API_TYPE",
            AzureMLEndpointApiType.dedicated,
        )
        values["timeout"] = get_from_dict_or_env(
            values,
            "timeout",
            "AZUREML_TIMEOUT",
            str(DEFAULT_TIMEOUT),
        )
        return values

    @validator("content_formatter")
    def validate_content_formatter(
        cls, field_value: Any, values: Dict
    ) -> ContentFormatterBase:
        """Validate that the content formatter supports the endpoint type."""
        endpoint_api_type = values.get("endpoint_api_type")
        if endpoint_api_type is None:
            # endpoint_api_type failed its own validation and that error is
            # already reported; skip the cross-field check.
            return field_value
        if endpoint_api_type not in field_value.supported_api_types:
            raise ValueError(
                f"Content formatter {type(field_value)} is not supported by this "
                f"endpoint. Supported types are {field_value.supported_api_types} "
                f"but endpoint is {endpoint_api_type}."
            )
        return field_value

    @validator("endpoint_url")
    def validate_endpoint_url(cls, field_value: Any) -> str:
        """Validate that the endpoint URL is a full invocation URL."""
        if field_value.endswith("/"):
            field_value = field_value[:-1]
        if field_value.endswith("inference.ml.azure.com"):
            raise ValueError(
                "`endpoint_url` should contain the full invocation URL including "
                "`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
                "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
            )
        return field_value

    @validator("endpoint_api_type")
    def validate_endpoint_api_type(
        cls, field_value: Any, values: Dict
    ) -> AzureMLEndpointApiType:
        """Validate that the endpoint API type matches the URL format."""
        endpoint_url = values.get("endpoint_url")
        if endpoint_url is None:
            # endpoint_url failed its own validation and is absent from
            # `values`; its error is already reported, so don't crash on it.
            return field_value
        if field_value in (
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ) and not endpoint_url.endswith("/score"):
            raise ValueError(
                "Endpoints of type `dedicated` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
                " If your endpoint URL ends with `/v1/completions` or "
                "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
            )
        if field_value == AzureMLEndpointApiType.serverless and not endpoint_url.endswith(
            ("/v1/completions", "/v1/chat/completions")
        ):
            raise ValueError(
                "Endpoints of type `serverless` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions`"
                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
            )
        return field_value

    @validator("http_client", always=True)
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Build the HTTP client from the validated connection fields.

        Any value passed for ``http_client`` is ignored and replaced.
        """
        endpoint_url = values.get("endpoint_url")
        endpoint_key = values.get("endpoint_api_key")
        deployment_name = values.get("deployment_name")
        timeout = values.get("timeout", DEFAULT_TIMEOUT)
        http_client = AzureMLEndpointClient(
            endpoint_url,  # type: ignore
            endpoint_key.get_secret_value(),  # type: ignore
            deployment_name,  # type: ignore
            timeout,  # type: ignore
        )
        return http_client
class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
    """Azure ML Online Endpoint models.

    Example:
        .. code-block:: python

            azure_llm = AzureMLOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_type=AzureMLEndpointApiType.dedicated,
                endpoint_api_key="my-api-key",
                timeout=120,
                content_formatter=content_formatter,
            )
    """  # noqa: E501

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "deployment_name": self.deployment_name,
            "model_kwargs": self.model_kwargs or {},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_endpoint"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts, one endpoint call per prompt.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Forwarded to the HTTP client's ``call``.
            **kwargs: Extra model parameters merged over ``model_kwargs``.

        Returns:
            An ``LLMResult`` with one single-element generation list per prompt.

        Example:
            .. code-block:: python

                response = azureml_model.invoke("Tell me a joke.")
        """
        # Build a fresh dict: updating self.model_kwargs in place would leak
        # this call's kwargs and `stop` into every subsequent call.
        _model_kwargs = {**(self.model_kwargs or {}), **kwargs}
        if stop:
            _model_kwargs["stop"] = stop

        generations = []
        for prompt in prompts:
            request_payload = self.content_formatter.format_request_payload(
                prompt, _model_kwargs, self.endpoint_api_type
            )
            response_payload = self.http_client.call(
                body=request_payload, run_manager=run_manager
            )
            generated_text = self.content_formatter.format_response_payload(
                response_payload, self.endpoint_api_type
            )
            generations.append([generated_text])

        return LLMResult(generations=generations)