Source code for langchain_community.llms.yuan2

import json
import logging
from typing import Any, Dict, List, Mapping, Optional, Set

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)


class Yuan2(LLM):
    """Yuan2.0 language model.

    Example:
        .. code-block:: python

            yuan_llm = Yuan2(
                infer_api="http://127.0.0.1:8000/yuan",
                max_tokens=1024,
                temp=1.0,
                top_p=0.9,
                top_k=40,
            )
            print(yuan_llm)
            print(yuan_llm.invoke("你是谁?"))
    """

    infer_api: str = "http://127.0.0.1:8000/yuan"
    """Yuan2.0 inference API endpoint."""

    max_tokens: int = Field(1024, alias="max_token")
    """Token context window."""

    temp: Optional[float] = 0.7
    """The temperature to use for sampling."""

    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""

    top_k: Optional[int] = 0
    """The top-k value to use for sampling."""

    do_sample: bool = False
    """Whether to use a sampling strategy during text generation."""

    echo: Optional[bool] = False
    """Whether to echo the prompt."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize."""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    streaming: bool = False
    """Whether to stream the results."""

    history: List[str] = []
    """Conversation history."""

    use_history: bool = False
    """Whether to use the conversation history."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Yuan2 class."""
        super().__init__(**kwargs)

        if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
            logger.warning(
                "top_p and top_k cannot be set simultaneously. "
                "set top_k to 0 instead..."
            )
            self.top_k = 0

    @property
    def _llm_type(self) -> str:
        return "Yuan2.0"

    @staticmethod
    def _model_param_names() -> Set[str]:
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        return {
            "do_sample": self.do_sample,
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "repeat_penalty": self.repeat_penalty,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v
                for k, v in self.__dict__.items()
                if k in self._model_param_names()
            },
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the Yuan2.0 LLM inference endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: An optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = yuan_llm.invoke("你能做什么?")
        """
        # With history enabled, join all turns with "<n>" into one input.
        if self.use_history:
            self.history.append(prompt)
            input_text = "<n>".join(self.history)
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input_text}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )

        logger.debug(f"Yuan2.0 prompt: {input_text}")

        # call the inference api
        try:
            response = requests.put(self.infer_api, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}")

        logger.debug(f"Yuan2.0 response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            resp = response.json()

            if resp["errCode"] != "0":
                raise ValueError(
                    f"Failed with error code [{resp['errCode']}], "
                    f"error message: [{resp['exceptionMsg']}]"
                )

            if "resData" in resp:
                if len(resp["resData"]["output"]) > 0:
                    generate_text = resp["resData"]["output"][0]["ans"]
                else:
                    raise ValueError("No output found in response.")
            else:
                raise ValueError("No resData found in response.")
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            generate_text = enforce_stop_tokens(generate_text, stop)

        # support multi-turn chat
        if self.use_history:
            self.history.append(generate_text)
            logger.debug(f"history: {self.history}")

        return generate_text