import json
import logging
from typing import Any, Dict, List, Mapping, Optional, Set
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field
from langchain_community.llms.utils import enforce_stop_tokens
# Logger for this module: emits the top_p/top_k conflict warning and the
# debug traces of prompts, responses and conversation history.
logger: logging.Logger = logging.getLogger(name=__name__)
class Yuan2(LLM):
    """Yuan2.0 language model.

    Example:
        .. code-block:: python

            yuan_llm = Yuan2(
                infer_api="http://127.0.0.1:8000/yuan",
                max_tokens=1024,
                temp=1.0,
                top_p=0.9,
                top_k=40,
            )
            print(yuan_llm)
            print(yuan_llm.invoke("你是谁?"))
    """

    infer_api: str = "http://127.0.0.1:8000/yuan"
    """Yuan2.0 inference API endpoint (called with HTTP PUT)."""

    max_tokens: int = Field(1024, alias="max_token")
    """Number of tokens to generate; sent to the API as ``tokens_to_generate``."""

    temp: Optional[float] = 0.7
    """Temperature used for sampling."""

    top_p: Optional[float] = 0.9
    """Top-p value used for sampling."""

    top_k: Optional[int] = 0
    """Top-k value used for sampling."""

    do_sample: bool = False
    """Whether the sampling method is used during text generation."""

    echo: Optional[bool] = False
    """Whether to echo the prompt (not sent in the request built by ``_call``)."""

    stop: Optional[List[str]] = []
    """Strings at which to stop generation (``_call`` only uses its ``stop`` arg)."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize (not sent in the request built by ``_call``)."""

    repeat_penalty: Optional[float] = 1.18
    """Penalty for repeated tokens (not sent in the request built by ``_call``)."""

    streaming: bool = False
    """Whether to stream results (``_call`` does not stream)."""

    history: List[str] = []
    """Conversation history: alternating prompts and answers when
    ``use_history`` is enabled."""

    use_history: bool = False
    """Whether to prepend the conversation history to each prompt."""

    timeout: Optional[float] = None
    """Seconds to wait for the inference API; ``None`` waits indefinitely."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Yuan2 class."""
        super().__init__(**kwargs)
        # top_p and top_k are treated as mutually exclusive; top_p wins.
        if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
            logger.warning(
                "top_p and top_k cannot be set simultaneously. "
                "set top_k to 0 instead..."
            )
            self.top_k = 0

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM type."""
        return "Yuan2.0"

    @staticmethod
    def _model_param_names() -> Set[str]:
        """Return the attribute names that count as model parameters."""
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        """Return the default parameters that identify this instance."""
        return {
            "do_sample": self.do_sample,
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "repeat_penalty": self.repeat_penalty,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
            },
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the Yuan2.0 LLM inference endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the HTTP request fails, the status code is not 200,
                the body is not valid JSON, ``errCode`` is not ``"0"``, or the
                response contains no ``resData`` / no output.

        Example:
            .. code-block:: python

                response = yuan_llm.invoke("你能做什么?")
        """
        if self.use_history:
            # Prior turns plus the new prompt, joined with the "<n>" separator.
            # History itself is only updated after a successful call.
            input_text = "<n>".join([*self.history, prompt])
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input_text}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )
        logger.debug("Yuan2.0 prompt: %s", input_text)

        # call api
        try:
            response = requests.put(
                self.infer_api, headers=headers, data=data, timeout=self.timeout
            )
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}") from e

        logger.debug("Yuan2.0 response: %s", response)
        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            resp = response.json()
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            ) from e

        if resp["errCode"] != "0":
            raise ValueError(
                f"Failed with error code [{resp['errCode']}], "
                f"error message: [{resp['exceptionMsg']}]"
            )
        if "resData" not in resp:
            raise ValueError("No resData found in response.")
        outputs = resp["resData"]["output"]
        # Guard the [0] index below: an empty output list is an API failure.
        if not outputs:
            raise ValueError("No output found in response.")
        generate_text = outputs[0]["ans"]

        # An empty stop list has nothing to cut on, so skip it instead of
        # handing an empty pattern to enforce_stop_tokens.
        if stop:
            generate_text = enforce_stop_tokens(generate_text, stop)

        # support multi-turn chat: record the prompt/answer pair together.
        if self.use_history:
            self.history.extend([prompt, generate_text])
            logger.debug("history: %s", self.history)
        return generate_text