Source code for langchain_community.llms.azureml_endpoint

import json
import urllib.request
import warnings
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

DEFAULT_TIMEOUT = 50


class AzureMLEndpointClient(object):
    """AzureML Managed Endpoint client."""

    def __init__(
        self,
        endpoint_url: str,
        endpoint_api_key: str,
        deployment_name: str = "",
        timeout: int = DEFAULT_TIMEOUT,
    ) -> None:
        """Initialize the class."""
        if not endpoint_api_key or not endpoint_url:
            raise ValueError(
                "A key/token and REST endpoint should be provided to invoke "
                "the endpoint"
            )
        self.endpoint_url = endpoint_url
        self.endpoint_api_key = endpoint_api_key
        self.deployment_name = deployment_name
        self.timeout = timeout

    def call(
        self,
        body: bytes,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> bytes:
        """Call the endpoint and return the raw response bytes."""
        # The azureml-model-deployment header will force the request to go to a
        # specific deployment. Remove this header to have the request observe the
        # endpoint traffic rules.
        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.endpoint_api_key),
        }
        if self.deployment_name != "":
            headers["azureml-model-deployment"] = self.deployment_name

        req = urllib.request.Request(self.endpoint_url, body, headers)
        response = urllib.request.urlopen(
            req, timeout=kwargs.get("timeout", self.timeout)
        )
        result = response.read()
        return result
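
# Minimal direct-usage sketch for the client above. The URL, key, and body
# shape are placeholders (hypothetical values); the request schema depends on
# the deployed model:
#
#     client = AzureMLEndpointClient(
#         endpoint_url="https://my-endpoint.eastus2.inference.ml.azure.com/score",
#         endpoint_api_key="my-api-key",
#     )
#     body = json.dumps(
#         {"inputs": {"input_string": ["Hello"]}, "parameters": {}}
#     ).encode("utf-8")
#     raw = client.call(body=body)  # raw response bytes; parse per model schema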


class AzureMLEndpointApiType(str, Enum):
    """Azure ML endpoints API types. Use `dedicated` for models deployed in
    hosted infrastructure (also known as Online Endpoints in Azure Machine
    Learning), or `serverless` for models deployed as a service, billed
    pay-as-you-go or with PTUs.
    """

    dedicated = "dedicated"
    realtime = "realtime"  #: Deprecated
    serverless = "serverless"


class ContentFormatterBase:
    """Transform the request and response of an AzureML endpoint to match the
    required schema.

    Example:
        .. code-block:: python

            class ContentFormatter(ContentFormatterBase):
                content_type = "application/json"
                accepts = "application/json"

                def format_request_payload(
                    self,
                    prompt: str,
                    model_kwargs: Dict,
                    api_type: AzureMLEndpointApiType,
                ) -> bytes:
                    input_str = json.dumps(
                        {
                            "inputs": {"input_string": [prompt]},
                            "parameters": model_kwargs,
                        }
                    )
                    return str.encode(input_str)

                def format_response_payload(
                    self, output: str, api_type: AzureMLEndpointApiType
                ) -> str:
                    response_json = json.loads(output)
                    return response_json[0]["0"]
    """

    content_type: Optional[str] = "application/json"
    """The MIME type of the input data passed to the endpoint"""

    accepts: Optional[str] = "application/json"
    """The MIME type of the response data returned from the endpoint"""

    format_error_msg: str = (
        "Error while formatting response payload for chat model of type "
        "`{api_type}`. Are you using the right formatter for the deployed "
        "model and endpoint type?"
    )

    @staticmethod
    def escape_special_characters(prompt: str) -> str:
        """Escape any special characters in `prompt`"""
        escape_map = {
            "\\": "\\\\",
            '"': '\\"',
            "\b": "\\b",
            "\f": "\\f",
            "\n": "\\n",
            "\r": "\\r",
            "\t": "\\t",
        }

        # Replace each occurrence of the specified characters with escaped versions
        for escape_sequence, escaped_sequence in escape_map.items():
            prompt = prompt.replace(escape_sequence, escaped_sequence)

        return prompt
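
    # A tiny sketch of what the escaping above does: it makes a prompt safe to
    # splice into a hand-built JSON string literal. For example:
    #
    #     escaped = ContentFormatterBase.escape_special_characters('say "hi"\n')
    #     assert escaped == 'say \\"hi\\"\\n'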

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Supported APIs for the given formatter. Azure ML supports
        deploying models using different hosting methods. Each method may have
        a different API structure.
        """
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Any:
        """Formats the request body according to the input schema of
        the model. Returns bytes or a seekable file-like object, in the
        format specified in the content_type request header.
        """
        raise NotImplementedError()

    @abstractmethod
    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Generation:
        """Formats the response body according to the output schema of
        the model. Returns the data type that is received from the
        response.
        """


class GPT2ContentFormatter(ContentFormatterBase):
    """Content handler for GPT2"""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        try:
            choice = json.loads(output)[0]["0"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)
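
# Round-trip sketch for this formatter. The response bytes below illustrate the
# `[{"0": ...}]` shape it expects; real payloads depend on the deployment, and
# `max_length` is just an example model parameter:
#
#     fmt = GPT2ContentFormatter()
#     payload = fmt.format_request_payload(
#         "Tell me a joke", {"max_length": 64}, AzureMLEndpointApiType.dedicated
#     )
#     gen = fmt.format_response_payload(
#         b'[{"0": "Why did the chicken cross the road?"}]',
#         AzureMLEndpointApiType.dedicated,
#     )
#     assert gen.text == "Why did the chicken cross the road?"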


class OSSContentFormatter(GPT2ContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Content handler for LLMs from the OSS catalog."""

    content_formatter: Any = None

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            "`OSSContentFormatter` will be deprecated in the future. "
            "Please use `GPT2ContentFormatter` instead."
        )


class HFContentFormatter(ContentFormatterBase):
    """Content handler for LLMs from the HuggingFace catalog."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        try:
            choice = json.loads(output)[0]["0"]["generated_text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)


class DollyContentFormatter(ContentFormatterBase):
    """Content handler for the Dolly-v2-12b model"""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {
                "input_data": {"input_string": [f'"{prompt}"']},
                "parameters": model_kwargs,
            }
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        try:
            choice = json.loads(output)[0]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)


class CustomOpenAIContentFormatter(ContentFormatterBase):
    """Content formatter for models that use an OpenAI-like API scheme."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Formats the request according to the chosen api"""
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": [f'"{prompt}"'],
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Formats the response according to the chosen api"""
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            try:
                choice = json.loads(output)[0]["0"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(text=choice)
        if api_type == AzureMLEndpointApiType.serverless:
            try:
                choice = json.loads(output)["choices"][0]
                if not isinstance(choice, dict):
                    raise TypeError(
                        "Endpoint response is not well formed for a chat "
                        f"model. Expected `dict` but `{type(choice)}` was "
                        "received."
                    )
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(
                text=choice["text"].strip(),
                generation_info=dict(
                    finish_reason=choice.get("finish_reason"),
                    logprobs=choice.get("logprobs"),
                ),
            )
        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
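
# Parsing sketch for the serverless (OpenAI-style) branch above. The JSON body
# is a hypothetical example of the `choices` shape that branch expects:
#
#     fmt = CustomOpenAIContentFormatter()
#     output = (
#         b'{"choices": [{"text": " Hi there!", '
#         b'"finish_reason": "stop", "logprobs": null}]}'
#     )
#     gen = fmt.format_response_payload(output, AzureMLEndpointApiType.serverless)
#     assert gen.text == "Hi there!"
#     assert gen.generation_info["finish_reason"] == "stop"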


class LlamaContentFormatter(CustomOpenAIContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Content formatter for Llama."""

    content_formatter: Any = None

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            "`LlamaContentFormatter` will be deprecated in the future. "
            "Please use `CustomOpenAIContentFormatter` instead."
        )


class AzureMLBaseEndpoint(BaseModel):
    """Azure ML Online Endpoint models."""

    endpoint_url: str = ""
    """URL of the pre-existing endpoint. Should be passed to the constructor or
        specified as env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
    """Type of the endpoint being consumed. Possible values are `serverless` for
        pay-as-you-go and `dedicated` for dedicated endpoints."""

    endpoint_api_key: SecretStr = convert_to_secret_str("")
    """Authentication key for the endpoint. Should be passed to the constructor or
        specified as env var `AZUREML_ENDPOINT_API_KEY`."""

    deployment_name: str = ""
    """Deployment name of the endpoint. Not required to call the endpoint. Should be
        passed to the constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""

    timeout: int = DEFAULT_TIMEOUT
    """Request timeout for calls to the endpoint"""

    http_client: Any = None  #: :meta private:

    max_retries: int = 1

    content_formatter: Any = None
    """The content formatter that provides an input and output transform
        function to handle formats between the LLM and the endpoint"""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    @root_validator(pre=True)
    def validate_environ(cls, values: Dict) -> Dict:
        values["endpoint_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
            )
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        values["deployment_name"] = get_from_dict_or_env(
            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
        )
        values["endpoint_api_type"] = get_from_dict_or_env(
            values,
            "endpoint_api_type",
            "AZUREML_ENDPOINT_API_TYPE",
            AzureMLEndpointApiType.dedicated,
        )
        values["timeout"] = get_from_dict_or_env(
            values,
            "timeout",
            "AZUREML_TIMEOUT",
            str(DEFAULT_TIMEOUT),
        )

        return values

    @validator("content_formatter")
    def validate_content_formatter(
        cls, field_value: Any, values: Dict
    ) -> ContentFormatterBase:
        """Validate that the content formatter is supported by the endpoint type."""
        endpoint_api_type = values.get("endpoint_api_type")
        if endpoint_api_type not in field_value.supported_api_types:
            raise ValueError(
                f"Content formatter {type(field_value)} is not supported by this "
                f"endpoint. Supported types are {field_value.supported_api_types} "
                f"but endpoint is {endpoint_api_type}."
            )
        return field_value

    @validator("endpoint_url")
    def validate_endpoint_url(cls, field_value: Any) -> str:
        """Validate that the endpoint url is complete."""
        if field_value.endswith("/"):
            field_value = field_value[:-1]
        if field_value.endswith("inference.ml.azure.com"):
            raise ValueError(
                "`endpoint_url` should contain the full invocation URL including "
                "`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
                "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
            )
        return field_value

    @validator("endpoint_api_type")
    def validate_endpoint_api_type(
        cls, field_value: Any, values: Dict
    ) -> AzureMLEndpointApiType:
        """Validate that the endpoint api type is compatible with the URL format."""
        endpoint_url = values.get("endpoint_url")
        if (
            field_value == AzureMLEndpointApiType.dedicated
            or field_value == AzureMLEndpointApiType.realtime
        ) and not endpoint_url.endswith("/score"):  # type: ignore[union-attr]
            raise ValueError(
                "Endpoints of type `dedicated` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
                " If your endpoint URL ends with `/v1/completions` or "
                "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
            )
        if field_value == AzureMLEndpointApiType.serverless and not (
            endpoint_url.endswith("/v1/completions")  # type: ignore[union-attr]
            or endpoint_url.endswith("/v1/chat/completions")  # type: ignore[union-attr]
        ):
            raise ValueError(
                "Endpoints of type `serverless` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions`"
                " or "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
            )

        return field_value

    @validator("http_client", always=True)
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Create the AzureMLEndpointClient used to call the endpoint."""
        endpoint_url = values.get("endpoint_url")
        endpoint_key = values.get("endpoint_api_key")
        deployment_name = values.get("deployment_name")
        timeout = values.get("timeout", DEFAULT_TIMEOUT)
        http_client = AzureMLEndpointClient(
            endpoint_url,  # type: ignore
            endpoint_key.get_secret_value(),  # type: ignore
            deployment_name,  # type: ignore
            timeout,  # type: ignore
        )
        return http_client
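
# Configuration sketch: every field above can also come from the environment
# instead of the constructor. The URL and key are placeholders:
#
#     import os
#
#     os.environ["AZUREML_ENDPOINT_URL"] = (
#         "https://my-endpoint.eastus2.inference.ml.azure.com/score"
#     )
#     os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key"
#     endpoint = AzureMLBaseEndpoint(content_formatter=GPT2ContentFormatter())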


class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
    """Azure ML Online Endpoint models.

    Example:
        .. code-block:: python

            azure_llm = AzureMLOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_type=AzureMLEndpointApiType.dedicated,
                endpoint_api_key="my-api-key",
                timeout=120,
                content_formatter=content_formatter,
            )
    """  # noqa: E501

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"deployment_name": self.deployment_name},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_endpoint"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = azureml_model.invoke("Tell me a joke.")
        """
        # Copy so that per-call kwargs do not mutate the configured model_kwargs.
        _model_kwargs = dict(self.model_kwargs or {})
        _model_kwargs.update(kwargs)
        if stop:
            _model_kwargs["stop"] = stop
        generations = []

        for prompt in prompts:
            request_payload = self.content_formatter.format_request_payload(
                prompt, _model_kwargs, self.endpoint_api_type
            )
            response_payload = self.http_client.call(
                body=request_payload, run_manager=run_manager
            )
            generated_text = self.content_formatter.format_response_payload(
                response_payload, self.endpoint_api_type
            )
            generations.append([generated_text])

        return LLMResult(generations=generations)
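

# End-to-end usage sketch against a serverless, OpenAI-style deployment; the
# endpoint URL and key are placeholders, and `temperature`/`max_tokens` are
# example model parameters. For a dedicated endpoint, use a `/score` URL,
# `AzureMLEndpointApiType.dedicated`, and a matching formatter such as
# `GPT2ContentFormatter`:
#
#     llm = AzureMLOnlineEndpoint(
#         endpoint_url=(
#             "https://my-endpoint.eastus2.inference.ml.azure.com/v1/completions"
#         ),
#         endpoint_api_type=AzureMLEndpointApiType.serverless,
#         endpoint_api_key="my-api-key",
#         content_formatter=CustomOpenAIContentFormatter(),
#     )
#     print(llm.invoke("Tell me a joke.", temperature=0.2, max_tokens=64))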