Source code for langchain_community.llms.koboldai
import logging
from typing import Any, Dict, List, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
logger = logging.getLogger(__name__)
[docs]def clean_url(url: str) -> str:
"""如果存在的话,删除URL末尾的斜杠和/api。"""
if url.endswith("/api"):
return url[:-4]
elif url.endswith("/"):
return url[:-1]
else:
return url
[docs]class KoboldApiLLM(LLM):
"""Kobold API语言模型。
它包括几个字段,可用于控制文本生成过程。
要使用此类,请使用所需的参数实例化它,并使用提示调用它以生成文本。例如:
kobold = KoboldApiLLM(endpoint="http://localhost:5000")
result = kobold("Write a story about a dragon.")
这将向Kobold API发送一个带有提供的提示的POST请求,并生成文本。"""
endpoint: str
"""用于生成文本的API端点。"""
use_story: Optional[bool] = False
"""在生成文本时是否使用KoboldAI GUI中的故事。"""
use_authors_note: Optional[bool] = False
"""是否在生成文本时使用KoboldAI GUI中的作者注释。
除非同时启用use_story,否则这没有任何效果。"""
use_world_info: Optional[bool] = False
"""在生成文本时是否使用KoboldAI GUI中的世界信息。"""
use_memory: Optional[bool] = False
"""在生成文本时是否使用KoboldAI GUI 中的内存。"""
max_context_length: Optional[int] = 1600
"""将要发送给模型的最大令牌数量。
最小值:1"""
max_length: Optional[int] = 80
"""需要生成的令牌数量。
最大值:512
最小值:1"""
rep_pen: Optional[float] = 1.12
"""基础重复惩罚值。
最小值:1"""
rep_pen_range: Optional[int] = 1024
"""重复惩罚范围。
最小值:0"""
rep_pen_slope: Optional[float] = 0.9
"""重复惩罚斜率。
最小值: 0"""
temperature: Optional[float] = 0.6
"""温度数值。
exclusiveMinimum: 0"""
tfs: Optional[float] = 0.9
"""尾部自由抽样数值。
最大值:1
最小值:0"""
top_a: Optional[float] = 0.9
"""顶部-a抽样值。
最小值:0"""
top_p: Optional[float] = 0.95
"""顶部p采样值。
最大值:1
最小值:0"""
top_k: Optional[int] = 0
"""Top-k抽样数值。
最小值:0"""
typical: Optional[float] = 0.5
"""典型的采样值。
最大值:1
最小值:0"""
@property
def _llm_type(self) -> str:
return "koboldai"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""调用API并返回输出。
参数:
prompt: 用于生成的提示。
stop: 遇到时停止生成的字符串列表。
返回:
生成的文本。
示例:
.. code-block:: python
from langchain_community.llms import KoboldApiLLM
llm = KoboldApiLLM(endpoint="http://localhost:5000")
llm.invoke("Write a story about dragons.")
"""
data: Dict[str, Any] = {
"prompt": prompt,
"use_story": self.use_story,
"use_authors_note": self.use_authors_note,
"use_world_info": self.use_world_info,
"use_memory": self.use_memory,
"max_context_length": self.max_context_length,
"max_length": self.max_length,
"rep_pen": self.rep_pen,
"rep_pen_range": self.rep_pen_range,
"rep_pen_slope": self.rep_pen_slope,
"temperature": self.temperature,
"tfs": self.tfs,
"top_a": self.top_a,
"top_p": self.top_p,
"top_k": self.top_k,
"typical": self.typical,
}
if stop is not None:
data["stop_sequence"] = stop
response = requests.post(
f"{clean_url(self.endpoint)}/api/v1/generate", json=data
)
response.raise_for_status()
json_response = response.json()
if (
"results" in json_response
and len(json_response["results"]) > 0
and "text" in json_response["results"][0]
):
text = json_response["results"][0]["text"].strip()
if stop is not None:
for sequence in stop:
if text.endswith(sequence):
text = text[: -len(sequence)].rstrip()
return text
else:
raise ValueError(
f"Unexpected response format from Kobold API: {json_response}"
)