Source code for langchain_community.llms.gpt4all

from functools import partial
from typing import Any, Dict, List, Mapping, Optional, Set

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, root_validator

from langchain_community.llms.utils import enforce_stop_tokens


class GPT4All(LLM):
    """GPT4All language models.

    To use, you should have the ``gpt4all`` python package installed, the
    pre-trained model file, and the model's config information.

    Example:
        .. code-block:: python

            from langchain_community.llms import GPT4All
            model = GPT4All(model="./models/gpt4all-model.bin", n_threads=8)

            # Simplest invocation
            response = model.invoke("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained GPT4All model file."""

    backend: Optional[str] = Field(None, alias="backend")

    max_tokens: int = Field(200, alias="max_tokens")
    """Token context window."""

    n_parts: int = Field(-1, alias="n_parts")
    """Number of parts to split the model into.
    If -1, the number of parts is automatically determined."""

    seed: int = Field(0, alias="seed")
    """Seed. If -1, a random seed is used."""

    f16_kv: bool = Field(False, alias="f16_kv")
    """Use half-precision for the key/value cache."""

    logits_all: bool = Field(False, alias="logits_all")
    """Return logits for all tokens, not just the last token."""

    vocab_only: bool = Field(False, alias="vocab_only")
    """Only load the vocabulary, not the weights."""

    use_mlock: bool = Field(False, alias="use_mlock")
    """Force the system to keep the model in RAM."""

    embedding: bool = Field(False, alias="embedding")
    """Use embedding mode only."""

    n_threads: Optional[int] = Field(4, alias="n_threads")
    """Number of threads to use."""

    n_predict: Optional[int] = 256
    """The maximum number of tokens to generate."""

    temp: Optional[float] = 0.7
    """The temperature to use for sampling."""

    top_p: Optional[float] = 0.1
    """The top-p value to use for sampling."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    echo: Optional[bool] = False
    """Whether to echo the prompt."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize."""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    n_batch: int = Field(8, alias="n_batch")
    """Batch size for prompt processing."""

    streaming: bool = False
    """Whether to stream the results or not."""

    allow_download: bool = False
    """If the model does not exist in ~/.cache/gpt4all/, download it."""

    device: Optional[str] = Field("cpu", alias="device")
    """Device name: cpu, gpu, nvidia, intel, amd or DeviceName."""

    client: Any = None  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @staticmethod
    def _model_param_names() -> Set[str]:
        return {
            "max_tokens",
            "n_predict",
            "top_k",
            "top_p",
            "temp",
            "n_batch",
            "repeat_penalty",
            "repeat_last_n",
            "streaming",
        }

    def _default_params(self) -> Dict[str, Any]:
        return {
            "max_tokens": self.max_tokens,
            "n_predict": self.n_predict,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temp": self.temp,
            "n_batch": self.n_batch,
            "repeat_penalty": self.repeat_penalty,
            "repeat_last_n": self.repeat_last_n,
            "streaming": self.streaming,
        }

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the gpt4all python package exists in the environment."""
        try:
            from gpt4all import GPT4All as GPT4AllModel
        except ImportError:
            raise ImportError(
                "Could not import gpt4all python package. "
                "Please install it with `pip install gpt4all`."
            )

        # Split the model path into a directory and a file name.
        full_path = values["model"]
        model_path, delimiter, model_name = full_path.rpartition("/")
        model_path += delimiter

        values["client"] = GPT4AllModel(
            model_name,
            model_path=model_path or None,
            model_type=values["backend"],
            allow_download=values["allow_download"],
            device=values["device"],
        )
        if values["n_threads"] is not None:
            # Set n_threads on the underlying model.
            values["client"].model.set_thread_count(values["n_threads"])

        try:
            values["backend"] = values["client"].model_type
        except AttributeError:
            # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
values["backend"] = values["client"].model.model_type return values @property def _identifying_params(self) -> Mapping[str, Any]: """获取识别参数。""" return { "model": self.model, **self._default_params(), **{ k: v for k, v in self.__dict__.items() if k in self._model_param_names() }, } @property def _llm_type(self) -> str: """返回llm的类型。""" return "gpt4all" def _call( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: r"""调用GPT4All的generate方法。 参数: prompt: 传递给模型的提示。 stop: 遇到时停止生成的字符串列表。 返回: 模型生成的字符串。 示例: .. code-block:: python prompt = "从前,有一只小猫," response = model.invoke(prompt, n_predict=55) """ text_callback = None if run_manager: text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose) text = "" params = {**self._default_params(), **kwargs} for token in self.client.generate(prompt, **params): if text_callback: text_callback(token) text += token if stop is not None: text = enforce_stop_tokens(text, stop) return text