Source code for langchain_community.llms.pai_eas_endpoint

import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)


class PaiEasEndpoint(LLM):
    """LangChain LLM class to help access an EAS LLM service.

    To use this endpoint, you must have a deployed EAS chat LLM service on
    PAI AliCloud. Set the fields ``eas_service_url`` and ``eas_service_token``
    directly, or provide them via the environment variables
    ``EAS_SERVICE_URL`` and ``EAS_SERVICE_TOKEN``.

    Example:
        .. code-block:: python

            from langchain_community.llms.pai_eas_endpoint import PaiEasEndpoint
            eas_endpoint = PaiEasEndpoint(
                eas_service_url="your_service_url",
                eas_service_token="your_service_token"
            )
    """

    """PAI-EAS service URL"""
    eas_service_url: str

    """PAI-EAS service token"""
    eas_service_token: str

    """PAI-EAS service inference parameters"""
    max_new_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.95
    top_p: Optional[float] = 0.1
    top_k: Optional[int] = 0
    stop_sequences: Optional[List[str]] = None

    """Enable streaming chat mode."""
    streaming: bool = False

    """Key/value arguments to pass to the model. Reserved for future use."""
    model_kwargs: Optional[dict] = None

    version: Optional[str] = "2.0"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the service URL and token exist in the environment."""
        values["eas_service_url"] = get_from_dict_or_env(
            values, "eas_service_url", "EAS_SERVICE_URL"
        )
        values["eas_service_token"] = get_from_dict_or_env(
            values, "eas_service_token", "EAS_SERVICE_TOKEN"
        )

        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "pai_eas_endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the EAS service."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "stop_sequences": [],
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            "eas_service_url": self.eas_service_url,
            "eas_service_token": self.eas_service_token,
            **_model_kwargs,
        }

    def _invocation_params(
        self, stop_sequences: Optional[List[str]], **kwargs: Any
    ) -> dict:
        params = self._default_params
        if self.stop_sequences is not None and stop_sequences is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop_sequences is not None:
            params["stop"] = self.stop_sequences
        else:
            params["stop"] = stop_sequences
        if self.model_kwargs:
            params.update(self.model_kwargs)
        return {**params, **kwargs}

    @staticmethod
    def _process_response(
        response: Any, stop: Optional[List[str]], version: Optional[str]
    ) -> str:
        if version == "1.0":
            text = response
        else:
            text = response["response"]

        if stop:
            text = enforce_stop_tokens(text, stop)
        return "".join(text)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        params = self._invocation_params(stop, **kwargs)
        prompt = prompt.strip()
        response = None
        try:
            if self.streaming:
                completion = ""
                for chunk in self._stream(prompt, stop, run_manager, **params):
                    completion += chunk.text
                return completion
            else:
                response = self._call_eas(prompt, params)
                _stop = params.get("stop")
                return self._process_response(response, _stop, self.version)
        except Exception as error:
            raise ValueError(f"Error raised by the service: {error}")

    def _call_eas(self, prompt: str = "", params: Dict = {}) -> Any:
        """Generate text from the EAS service."""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"{self.eas_service_token}",
        }
        if self.version == "1.0":
            body = {
                "input_ids": f"{prompt}",
            }
        else:
            body = {
                "prompt": f"{prompt}",
            }

        # add params to body
        for key, value in params.items():
            body[key] = value

        # make request
        response = requests.post(self.eas_service_url, headers=headers, json=body)

        if response.status_code != 200:
            raise Exception(
                f"Request failed with status code {response.status_code}"
                f" and message {response.text}"
            )

        try:
            return json.loads(response.text)
        except Exception as e:
            if isinstance(e, json.decoder.JSONDecodeError):
                return response.text
            raise e

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        invocation_params = self._invocation_params(stop, **kwargs)

        headers = {
            "User-Agent": "Test Client",
            "Authorization": f"{self.eas_service_token}",
        }

        if self.version == "1.0":
            pload = {"input_ids": prompt, **invocation_params}
            response = requests.post(
                self.eas_service_url, headers=headers, json=pload, stream=True
            )

            res = GenerationChunk(text=response.text)

            if run_manager:
                run_manager.on_llm_new_token(res.text)

            # yield text, if any
            yield res
        else:
            pload = {"prompt": prompt, "use_stream_chat": "True", **invocation_params}

            response = requests.post(
                self.eas_service_url, headers=headers, json=pload, stream=True
            )

            for chunk in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
            ):
                if chunk:
                    data = json.loads(chunk.decode("utf-8"))
                    output = data["response"]
                    # identify stop sequence in generated text, if any
                    stop_seq_found: Optional[str] = None
                    # guard against ``stop`` being None when no stop
                    # sequences were supplied
                    for stop_seq in invocation_params["stop"] or []:
                        if stop_seq in output:
                            stop_seq_found = stop_seq

                    # identify text to yield
                    text: Optional[str] = None
                    if stop_seq_found:
                        text = output[: output.index(stop_seq_found)]
                    else:
                        text = output

                    # yield text, if any
                    if text:
                        res = GenerationChunk(text=text)
                        if run_manager:
                            run_manager.on_llm_new_token(res.text)
                        yield res

                    # break if stop sequence found
                    if stop_seq_found:
                        break