from __future__ import annotations
import json
from io import StringIO
from typing import Any, Dict, Iterator, List, Optional
import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra
from langchain_core.utils import get_pydantic_field_names
class Llamafile(LLM):
"""Llamafile lets you distribute and run large language models with a single file.
To get started, see: https://github.com/Mozilla-Ocho/llamafile
To use this class, you will first need to:
1. Download a llamafile.
2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
3. Start the llamafile in server mode:
`./path/to/model.llamafile --server --nobrowser`
Example:
.. code-block:: python
from langchain_community.llms import Llamafile
llm = Llamafile()
llm.invoke("Tell me a joke.")"""
base_url: str = "http://localhost:8080"
"""LLamafile服务器正在监听的基本URL。"""
request_timeout: Optional[int] = None
"""服务器请求的超时时间"""
streaming: bool = False
"""允许实时接收每个预测的标记,而不是等待完成。要启用此功能,请设置为true。"""
# Generation options
seed: int = -1
"""随机数生成器(RNG)种子。如果小于零,则使用随机种子。默认值:-1"""
temperature: float = 0.8
"""温度。默认值:0.8"""
top_k: int = 40
"""将下一个标记的选择限制在K个最有可能的标记中。
默认值:40。"""
top_p: float = 0.95
"""将下一个令牌的选择限制在累积概率高于阈值P的令牌子集中。默认值为0.95。"""
min_p: float = 0.05
"""最小概率,用于考虑一个标记的相对概率,相对于最有可能的标记的概率。默认值为0.05。"""
n_predict: int = -1
"""设置生成文本时要预测的最大标记数。
注意:如果最后一个标记是部分多字节字符,则可能会略微超过设置的限制。
当为0时,不会生成任何标记,但提示将被评估到缓存中。默认值:-1 = 无限。"""
n_keep: int = 0
"""指定当上下文大小超过并且需要丢弃标记时要保留的标记数量。默认情况下,此值设置为0(表示不保留任何标记)。使用-1保留来自提示的所有标记。"""
tfs_z: float = 1.0
"""启用尾部免费采样,参数为z。默认值:1.0 = 禁用。"""
typical_p: float = 1.0
"""启用具有参数p的本地典型抽样。
默认值:1.0 = 禁用。"""
repeat_penalty: float = 1.1
"""控制生成文本中令牌序列的重复。默认值为1.1。"""
repeat_last_n: int = 64
"""最后考虑用于惩罚重复的n个标记。默认值为64,
0 = 禁用,-1 = 上下文大小。"""
penalize_nl: bool = True
"""在应用重复惩罚时,对换行符令牌进行惩罚。
默认值:true。"""
presence_penalty: float = 0.0
"""重复alpha存在惩罚。默认值:0.0 = 禁用。"""
frequency_penalty: float = 0.0
"""重复alpha频率惩罚。默认值:0.0 = 禁用"""
mirostat: int = 0
"""启用Mirostat采样,在文本生成过程中控制困惑度。0 = 禁用,1 = Mirostat,2 = Mirostat 2.0。默认值:禁用。"""
mirostat_tau: float = 5.0
"""设置Mirostat目标熵参数tau。默认值为5.0。"""
mirostat_eta: float = 0.1
"""设置Mirostat学习率参数eta。默认值为0.1。"""
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
@property
def _llm_type(self) -> str:
return "llamafile"
@property
def _param_fieldnames(self) -> List[str]:
# Return the list of fieldnames that will be passed as configurable
# generation options to the llamafile server. Exclude 'builtin' fields
# from the BaseLLM class like 'metadata' as well as fields that should
# not be passed in requests (base_url, request_timeout).
ignore_keys = [
"base_url",
"cache",
"callback_manager",
"callbacks",
"metadata",
"name",
"request_timeout",
"streaming",
"tags",
"verbose",
"custom_get_token_ids",
]
attrs = [
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
]
return attrs
@property
def _default_params(self) -> Dict[str, Any]:
params = {}
for fieldname in self._param_fieldnames:
params[fieldname] = getattr(self, fieldname)
return params
def _get_parameters(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
params = self._default_params
# Only update keys that are already present in the default config.
# This way, we don't accidentally post unknown/unhandled key/values
# in the request to the llamafile server
for k, v in kwargs.items():
if k in params:
params[k] = v
if stop is not None and len(stop) > 0:
params["stop"] = stop
if self.streaming:
params["stream"] = True
return params
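# Illustration (not part of the upstream class): for a default instance,
# `self._get_parameters(stop=["\n"], temperature=0.2)` returns a dict such as
# {"seed": -1, "temperature": 0.2, "top_k": 40, ..., "stop": ["\n"]}.
# Unknown kwargs are silently dropped because only keys already present in
# `_default_params` are updated.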
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""请求llamafile服务器完成提示,并返回输出。
参数:
prompt: 用于生成的提示。
stop: 遇到时停止生成的字符串列表。
run_manager:
**kwargs: 作为生成请求的一部分传递的任何其他选项。
返回:
模型生成的字符串。
"""
if self.streaming:
with StringIO() as buff:
for chunk in self._stream(
prompt, stop=stop, run_manager=run_manager, **kwargs
):
buff.write(chunk.text)
text = buff.getvalue()
return text
else:
params = self._get_parameters(stop=stop, **kwargs)
payload = {"prompt": prompt, **params}
try:
response = requests.post(
url=f"{self.base_url}/completion",
headers={
"Content-Type": "application/json",
},
json=payload,
stream=False,
timeout=self.request_timeout,
)
except requests.exceptions.ConnectionError:
raise requests.exceptions.ConnectionError(
f"Could not connect to Llamafile server. Please make sure "
f"that a server is running at {self.base_url}."
)
response.raise_for_status()
response.encoding = "utf-8"
text = response.json()["content"]
return text
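# Usage sketch (illustrative; assumes a llamafile server is reachable at the
# default base_url):
#
#     llm = Llamafile(temperature=0.2)
#     text = llm.invoke("Write a haiku about the sea.", n_predict=64)
#
# Extra keyword arguments such as `n_predict` are merged into the request by
# `_get_parameters` and posted to the server's /completion endpoint.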
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""以实时生成的方式产生结果对象。
它还调用回调管理器的on_llm_new_token事件,参数与OpenAI LLM类方法中的同名方法类似。
参数:
prompt: 传递给模型的提示。
stop: 生成时使用的可选停止词列表。
run_manager:
**kwargs: 作为生成请求的一部分传递的任何其他选项。
返回:
表示正在生成的标记流的生成器。
产生:
类似字典的对象,每个对象包含一个标记
示例:
.. code-block:: python
from langchain_community.llms import Llamafile
llm = Llamafile(
temperature = 0.0
)
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
stop=["'","
"]):
result = chunk["choices"][0]
print(result["text"], end='', flush=True)
"""
params = self._get_parameters(stop=stop, **kwargs)
if "stream" not in params:
params["stream"] = True
payload = {"prompt": prompt, **params}
try:
response = requests.post(
url=f"{self.base_url}/completion",
headers={
"Content-Type": "application/json",
},
json=payload,
stream=True,
timeout=self.request_timeout,
)
except requests.exceptions.ConnectionError:
raise requests.exceptions.ConnectionError(
f"Could not connect to Llamafile server. Please make sure "
f"that a server is running at {self.base_url}."
)
response.encoding = "utf8"
for raw_chunk in response.iter_lines(decode_unicode=True):
content = self._get_chunk_content(raw_chunk)
chunk = GenerationChunk(text=content)
if run_manager:
run_manager.on_llm_new_token(token=chunk.text)
yield chunk
def _get_chunk_content(self, chunk: str) -> str:
"""当流式传输打开时,llamafile服务器返回如下行:
'data: {"content":" They","multimodal":true,"slot_id":0,"stop":false}'
在这里,我们将其转换为字典并返回'content'字段的值。
"""
if chunk.startswith("data:"):
cleaned = chunk.lstrip("data: ")
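# Note: str.lstrip removes a *set* of characters, not a literal prefix.
# It behaves as intended here only because the JSON payload begins with
# "{", which is not in the stripped character set.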
data = json.loads(cleaned)
return data["content"]
else:
return chunk
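if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming a llamafile has been started in
    # server mode at the default base_url (http://localhost:8080), e.g.
    #   ./path/to/model.llamafile --server --nobrowser
    llm = Llamafile(temperature=0.0)
    print(llm.invoke("Tell me a joke.", n_predict=64))
    # The public stream() method yields plain strings built from the
    # GenerationChunk objects produced by _stream().
    for token in llm.stream("Count to five:", stop=["\n"]):
        print(token, end="", flush=True)
    print()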