Source code for langchain_community.llms.textgen

import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Field

logger = logging.getLogger(__name__)


class TextGen(LLM):
    """Text generation models from WebUI.

    To use, you should have the text-generation-webui installed, a model loaded,
    and --api added as a command line option.

    Suggested installation, use one-click installer for your OS:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    Parameters below taken from text-generation-webui api example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            from langchain_community.llms import TextGen
            llm = TextGen(model_url="http://localhost:8500")
    """

    model_url: str
    """The full URL to the textgen webui, including http[s]://host:port"""

    preset: Optional[str] = None
    """The preset to use in the textgen webui."""

    max_new_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""

    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""

    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""

    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""

    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely to
    appear than random tokens, given the prior text."""

    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff"""

    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """ETA cutoff"""

    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""

    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""

    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""

    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely blocked
    from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""

    num_beams: Optional[int] = 1
    """Number of beams"""

    penalty_alpha: Optional[float] = 0
    """Penalty Alpha"""

    length_penalty: Optional[float] = 1
    """Length Penalty"""

    early_stopping: bool = Field(False, alias="early_stopping")
    """Early stopping"""

    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""

    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""

    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""

    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation prematurely."""

    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""

    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    streaming: bool = False
    """Whether to stream the results, token by token."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling textgen."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "epsilon_cutoff": self.epsilon_cutoff,
            "eta_cutoff": self.eta_cutoff,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "penalty_alpha": self.penalty_alpha,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "truncation_length": self.truncation_length,
            "ban_eos_token": self.ban_eos_token,
            "skip_special_tokens": self.skip_special_tokens,
            "stopping_strings": self.stopping_strings,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_url": self.model_url}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "textgen"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Performs sanity check, preparing parameters in format needed by textgen.

        Args:
            stop (Optional[List[str]]): List of stop sequences for textgen.

        Returns:
            Dictionary containing the combined parameters.
        """

        # Raise error if stop sequences are in both input and default params
        # if self.stop and stop is not None:
        if self.stopping_strings and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")

        if self.preset is None:
            params = self._default_params
        else:
            params = {"preset": self.preset}

        # then sets it as configured, or default to an empty list:
        params["stopping_strings"] = self.stopping_strings or stop or []

        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(model_url="http://localhost:5000")
                llm.invoke("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output

        else:
            url = f"{self.model_url}/api/v1/generate"
            params = self._get_parameters(stop)
            request = params.copy()
            request["prompt"] = prompt
            response = requests.post(url, json=request)

            if response.status_code == 200:
                result = response.json()["results"][0]["text"]
            else:
                print(f"ERROR: response: {response}")  # noqa: T201
                result = ""

        return result

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(model_url="http://localhost:5000")
                llm.invoke("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output

        else:
            url = f"{self.model_url}/api/v1/generate"
            params = self._get_parameters(stop)
            request = params.copy()
            request["prompt"] = prompt
            response = requests.post(url, json=request)

            if response.status_code == 200:
                result = response.json()["results"][0]["text"]
            else:
                print(f"ERROR: response: {response}")  # noqa: T201
                result = ""

        return result

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yields result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            A dictionary-like object containing a string token and metadata.
            See text-generation-webui docs and below for more.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True
                )
                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
                        stop=["'", "\n"]):
                    print(chunk, end='', flush=True)  # noqa: T201
        """
        try:
            import websocket
        except ImportError:
            raise ImportError(
                "The `websocket-client` package is required for streaming."
            )

        params = {**self._get_parameters(stop), **kwargs}

        url = f"{self.model_url}/api/v1/stream"

        request = params.copy()
        request["prompt"] = prompt

        websocket_client = websocket.WebSocket()

        websocket_client.connect(url)

        websocket_client.send(json.dumps(request))

        while True:
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
                    text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                yield chunk
            elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return

            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Yields result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            A dictionary-like object containing a string token and metadata.
            See text-generation-webui docs and below for more.

        Example:
            .. code-block:: python

                from langchain_community.llms import TextGen
                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True
                )
                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
                        stop=["'", "\n"]):
                    print(chunk, end='', flush=True)  # noqa: T201
        """
        try:
            import websocket
        except ImportError:
            raise ImportError(
                "The `websocket-client` package is required for streaming."
            )

        params = {**self._get_parameters(stop), **kwargs}

        url = f"{self.model_url}/api/v1/stream"

        request = params.copy()
        request["prompt"] = prompt

        websocket_client = websocket.WebSocket()

        websocket_client.connect(url)

        websocket_client.send(json.dumps(request))

        while True:
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
                    text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                yield chunk
            elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return

            if run_manager:
                await run_manager.on_llm_new_token(token=chunk.text)
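
A minimal usage sketch for the class above, assuming a local text-generation-webui instance started with the --api flag. The host/port values (http://localhost:5000 for the blocking HTTP endpoint, ws://localhost:5005 for the streaming websocket endpoint) are placeholders taken from the docstring examples, not values enforced by this module.

    from langchain_community.llms import TextGen

    # Blocking call: POSTs to {model_url}/api/v1/generate and returns the full text.
    llm = TextGen(model_url="http://localhost:5000", max_new_tokens=64)
    print(llm.invoke("Write a story about llamas."))

    # Streaming call: opens a websocket to {model_url}/api/v1/stream and yields
    # tokens as they arrive (requires the `websocket-client` package).
    streaming_llm = TextGen(model_url="ws://localhost:5005", streaming=True)
    for token in streaming_llm.stream("Write a story about llamas."):
        print(token, end="", flush=True)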