import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Field
logger = logging.getLogger(__name__)
class TextGen(LLM):
    """Text generation models from WebUI.

    To use, you should have the text-generation-webui installed, a model loaded,
    and --api added as a command line option.

    Suggested installation, use one-click installer for your OS:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    Parameters below taken from text-generation-webui api example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            from langchain_community.llms import TextGen
            llm = TextGen(model_url="http://localhost:8500")
    """
    model_url: str
    """The full URL to the textgen webui, including http[s]://host:port"""

    preset: Optional[str] = None
    """The preset to use in the textgen webui"""

    max_new_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""

    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""

    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""

    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""

    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely
    to appear than random tokens, given the prior text."""

    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff"""

    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """ETA cutoff"""

    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""

    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""

    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""

    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely
    blocked from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""

    num_beams: Optional[int] = 1
    """Number of beams"""

    penalty_alpha: Optional[float] = 0
    """Penalty Alpha"""

    length_penalty: Optional[float] = 1
    """Length Penalty"""

    early_stopping: bool = Field(False, alias="early_stopping")
    """Early stopping"""

    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""

    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""

    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""

    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation
    prematurely."""

    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""

    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    streaming: bool = False
    """Whether to stream the results, token by token."""
@property
def _default_params(self) -> Dict[str, Any]:
"""获取调用textgen时的默认参数。"""
return {
"max_new_tokens": self.max_new_tokens,
"do_sample": self.do_sample,
"temperature": self.temperature,
"top_p": self.top_p,
"typical_p": self.typical_p,
"epsilon_cutoff": self.epsilon_cutoff,
"eta_cutoff": self.eta_cutoff,
"repetition_penalty": self.repetition_penalty,
"top_k": self.top_k,
"min_length": self.min_length,
"no_repeat_ngram_size": self.no_repeat_ngram_size,
"num_beams": self.num_beams,
"penalty_alpha": self.penalty_alpha,
"length_penalty": self.length_penalty,
"early_stopping": self.early_stopping,
"seed": self.seed,
"add_bos_token": self.add_bos_token,
"truncation_length": self.truncation_length,
"ban_eos_token": self.ban_eos_token,
"skip_special_tokens": self.skip_special_tokens,
"stopping_strings": self.stopping_strings,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""获取识别参数。"""
return {**{"model_url": self.model_url}, **self._default_params}
@property
def _llm_type(self) -> str:
"""llm的返回类型。"""
return "textgen"
def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""执行健全性检查,准备文本生成器所需格式的参数。
参数:
stop (Optional[List[str]]): 用于文本生成器的停止序列列表。
返回:
包含组合参数的字典。
"""
# Raise error if stop sequences are in both input and default params
# if self.stop and stop is not None:
if self.stopping_strings and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
if self.preset is None:
params = self._default_params
else:
params = {"preset": self.preset}
# then sets it as configured, or default to an empty list:
params["stopping_strings"] = self.stopping_strings or stop or []
return params
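    # For example, with preset=None and stop=["\n"], the returned dict contains all
    # default sampling params plus {"stopping_strings": ["\n"]}; with a preset set,
    # only {"preset": <name>, "stopping_strings": [...]} is sent to the server.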
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""调用textgen web API 并返回输出。
参数:
prompt: 用于生成的提示。
stop: 遇到时停止生成的字符串列表。
返回:
生成的文本。
示例:
.. code-block:: python
from langchain_community.llms import TextGen
llm = TextGen(model_url="http://localhost:5000")
llm.invoke("Write a story about llamas.")
"""
if self.streaming:
combined_text_output = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
combined_text_output += chunk.text
result = combined_text_output
else:
url = f"{self.model_url}/api/v1/generate"
params = self._get_parameters(stop)
request = params.copy()
request["prompt"] = prompt
response = requests.post(url, json=request)
if response.status_code == 200:
result = response.json()["results"][0]["text"]
else:
print(f"ERROR: response: {response}") # noqa: T201
result = ""
return result
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""调用textgen web API 并返回输出。
参数:
prompt: 用于生成的提示。
stop: 遇到时停止生成的字符串列表。
返回:
生成的文本。
示例:
.. code-block:: python
from langchain_community.llms import TextGen
llm = TextGen(model_url="http://localhost:5000")
llm.invoke("Write a story about llamas.")
"""
if self.streaming:
combined_text_output = ""
async for chunk in self._astream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
combined_text_output += chunk.text
result = combined_text_output
else:
url = f"{self.model_url}/api/v1/generate"
params = self._get_parameters(stop)
request = params.copy()
request["prompt"] = prompt
response = requests.post(url, json=request)
if response.status_code == 200:
result = response.json()["results"][0]["text"]
else:
print(f"ERROR: response: {response}") # noqa: T201
result = ""
return result
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""以实时生成的方式产生结果对象。
它还使用类似于OpenAI LLM类同名方法的参数调用回调管理器的on_llm_new_token事件。
参数:
prompt: 传递给模型的提示。
stop: 生成时使用的可选停止词列表。
返回:
表示正在生成的标记流的生成器。
产生:
类似于包含字符串标记和元数据的对象的字典。
有关更多信息,请参阅text-generation-webui文档和下面的内容。
示例:
.. code-block:: python
from langchain_community.llms import TextGen
llm = TextGen(
model_url = "ws://localhost:5005"
streaming=True
)
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
stop=["'","
"]):
print(chunk, end='', flush=True) # noqa: T201
"""
try:
import websocket
except ImportError:
raise ImportError(
"The `websocket-client` package is required for streaming."
)
params = {**self._get_parameters(stop), **kwargs}
url = f"{self.model_url}/api/v1/stream"
request = params.copy()
request["prompt"] = prompt
websocket_client = websocket.WebSocket()
websocket_client.connect(url)
websocket_client.send(json.dumps(request))
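        # The stream endpoint replies with a sequence of JSON events: each
        # "text_stream" event carries an incremental chunk of text, and a final
        # "stream_end" event terminates the stream.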
        while True:
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
                    text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(token=chunk.text)
            elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
"""以实时生成的方式产生结果对象。
它还使用类似于OpenAI LLM类同名方法的参数调用回调管理器的on_llm_new_token事件。
参数:
prompt: 传递给模型的提示。
stop: 生成时使用的可选停止词列表。
返回:
表示正在生成的标记流的生成器。
产生:
类似于包含字符串标记和元数据的对象的字典。
有关更多信息,请参阅text-generation-webui文档和下面的内容。
示例:
.. code-block:: python
from langchain_community.llms import TextGen
llm = TextGen(
model_url = "ws://localhost:5005"
streaming=True
)
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
stop=["'","
"]):
print(chunk, end='', flush=True) # noqa: T201
"""
try:
import websocket
except ImportError:
raise ImportError(
"The `websocket-client` package is required for streaming."
)
params = {**self._get_parameters(stop), **kwargs}
url = f"{self.model_url}/api/v1/stream"
request = params.copy()
request["prompt"] = prompt
websocket_client = websocket.WebSocket()
websocket_client.connect(url)
websocket_client.send(json.dumps(request))
        while True:
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
                    text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(token=chunk.text)
            elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return
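
# --- Usage sketch (illustrative, not part of the module API) -----------------
# A minimal sketch assuming a local text-generation-webui server with the
# legacy --api endpoints enabled; URLs and prompts are placeholders.
if __name__ == "__main__":
    llm = TextGen(model_url="http://localhost:5000")
    print(llm.invoke("Write a story about llamas."))

    # Streaming goes over the websocket endpoint and requires the
    # websocket-client package.
    streaming_llm = TextGen(model_url="ws://localhost:5005", streaming=True)
    for chunk in streaming_llm.stream("Say hello like a pirate."):
        print(chunk, end="", flush=True)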