Source code for langchain_community.llms.rwkv

"""RWKV models.

Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py
         https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
"""
from typing import Any, Dict, List, Mapping, Optional, Set

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.llms.utils import enforce_stop_tokens


class RWKV(LLM, BaseModel):
    """RWKV language models.

    To use, you should have the ``rwkv`` and ``tokenizers`` python packages
    installed, plus a pre-trained model file and the tokenizer (tokens) file.

    Example:
        .. code-block:: python

            from langchain_community.llms import RWKV

            # Paths are illustrative; point them at your own files.
            model = RWKV(
                model="./models/rwkv-3b-fp16.bin",
                tokens_path="./rwkv/20B_tokenizer.json",
                strategy="cpu fp32",
            )

            # Simplest invocation
            response = model.invoke("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained RWKV model file."""

    tokens_path: str
    """Path to the RWKV tokens (tokenizer) file."""

    strategy: str = "cpu fp32"
    """Strategy string passed to ``rwkv.model.RWKV`` (e.g. ``"cpu fp32"``)."""

    rwkv_verbose: bool = True
    """Print debug information (forwarded as ``verbose`` to the RWKV model)."""

    temperature: float = 1.0
    """The temperature to use for sampling."""

    top_p: float = 0.5
    """The top-p value to use for sampling."""

    penalty_alpha_frequency: float = 0.4
    """Positive values penalize new tokens based on their existing frequency
    in the text so far, decreasing the model's likelihood to repeat the same
    line verbatim."""

    penalty_alpha_presence: float = 0.4
    """Positive values penalize new tokens based on whether they appear
    in the text so far, increasing the model's likelihood to talk about
    new topics."""

    CHUNK_LEN: int = 256
    """Batch size for prompt processing."""

    max_tokens_per_generation: int = 256
    """Maximum number of tokens to generate."""

    client: Any = None  #: :meta private:

    tokenizer: Any = None  #: :meta private:

    pipeline: Any = None  #: :meta private:

    model_tokens: Any = None  #: :meta private:

    model_state: Any = None  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default sampling/generation parameters."""
        return {
            "verbose": self.verbose,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_alpha_frequency": self.penalty_alpha_frequency,
            "penalty_alpha_presence": self.penalty_alpha_presence,
            "CHUNK_LEN": self.CHUNK_LEN,
            "max_tokens_per_generation": self.max_tokens_per_generation,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Get the names of fields that are forwarded to the RWKV model."""
        return {
            "verbose",
        }

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the required packages exist and build the model.

        Populates ``tokenizer``, ``client`` and ``pipeline`` in ``values``.
        Skipped when field validation already failed (e.g. ``tokens_path``
        missing), so pydantic reports a ValidationError instead of this
        validator failing with a KeyError.

        Raises:
            ImportError: If ``tokenizers`` or ``rwkv`` is not installed.
        """
        try:
            import tokenizers
        except ImportError as e:
            raise ImportError(
                "Could not import tokenizers python package. "
                "Please install it with `pip install tokenizers`."
            ) from e
        # Keep the try narrow: only the imports should be reported as a
        # missing-package error; errors while loading files must propagate.
        try:
            from rwkv.model import RWKV as RWKVMODEL
            from rwkv.utils import PIPELINE
        except ImportError as e:
            raise ImportError(
                "Could not import rwkv python package. "
                "Please install it with `pip install rwkv`."
            ) from e

        values["tokenizer"] = tokenizers.Tokenizer.from_file(values["tokens_path"])

        rwkv_keys = cls._rwkv_param_names()
        model_kwargs = {k: v for k, v in values.items() if k in rwkv_keys}
        # The RWKV model's verbosity is controlled by `rwkv_verbose`,
        # overriding the LLM-level `verbose` flag.
        model_kwargs["verbose"] = values["rwkv_verbose"]
        values["client"] = RWKVMODEL(
            values["model"], strategy=values["strategy"], **model_kwargs
        )
        values["pipeline"] = PIPELINE(values["client"], values["tokens_path"])
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            **self._default_params,
            **{k: v for k, v in self.__dict__.items() if k in RWKV._rwkv_param_names()},
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"

    def run_rnn(self, _tokens: List[int], newline_adj: int = 0) -> Any:
        """Feed tokens through the model and return next-token logits.

        The tokens are appended to ``self.model_tokens`` and
        ``self.model_state`` is advanced. Input is fed to ``client.forward``
        in chunks of ``CHUNK_LEN``.

        Args:
            _tokens: Token ids to feed; each element is converted with ``int()``.
            newline_adj: Amount added to the logit of token id 187, which is
                treated as the end-of-line token (as in ChatRWKV).

        Returns:
            The logits from the final ``client.forward`` call, with the
            newline adjustment applied. If the last token is one of
            ``,:?!``, repeating it is suppressed by setting its logit very low.

        Raises:
            ValueError: If ``_tokens`` is empty (there would be no logits).
        """
        # Single-character punctuation tokens that should not repeat
        # back-to-back.
        AVOID_REPEAT_TOKENS = []
        AVOID_REPEAT = ",:?!"
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            AVOID_REPEAT_TOKENS += dd

        tokens = [int(x) for x in _tokens]
        if not tokens:
            raise ValueError("run_rnn requires at least one token to feed.")
        self.model_tokens += tokens

        out: Any = None
        while len(tokens) > 0:
            out, self.model_state = self.client.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]
        END_OF_LINE = 187
        out[END_OF_LINE] += newline_adj  # adjust \n probability

        if self.model_tokens[-1] in AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out

    def rwkv_generate(self, prompt: str) -> str:
        """Generate a continuation of ``prompt``.

        Resets the model state, feeds the prompt, then samples up to
        ``max_tokens_per_generation`` tokens. Sampling applies presence and
        frequency penalties and stops early on token id 0 (end of text).

        Args:
            prompt: The text to continue.

        Returns:
            The decoded generated text. Trailing tokens that do not yet decode
            to valid UTF-8 (the decode contains U+FFFD) are not included.
        """
        self.model_state = None
        self.model_tokens = []
        logits = self.run_rnn(self.tokenizer.encode(prompt).ids)
        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}
        decoded = ""
        for i in range(self.max_tokens_per_generation):
            # Penalize tokens already generated: a flat presence penalty plus
            # a frequency penalty scaled by how often each token appeared.
            for n in occurrence:
                logits[n] -= (
                    self.penalty_alpha_presence
                    + occurrence[n] * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )

            END_OF_TEXT = 0
            if token == END_OF_TEXT:
                break
            occurrence[token] = occurrence.get(token, 0) + 1

            logits = self.run_rnn([token])
            xxx = self.tokenizer.decode(self.model_tokens[out_last:])
            # Only emit once the pending tokens decode to complete characters;
            # otherwise keep accumulating (avoids UTF-8 display issues).
            if "\ufffd" not in xxx:
                decoded += xxx
                out_last = begin + i + 1
        return decoded

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.
            run_manager: Callback manager for the run (currently unused).
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model.invoke(prompt)
        """
        text = self.rwkv_generate(prompt)
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text