import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Extra
def _stream_response_to_generation_chunk(
stream_response: str,
) -> GenerationChunk:
"""将流响应转换为生成块。"""
parsed_response = json.loads(stream_response)
generation_info = parsed_response if parsed_response.get("done") is True else None
return GenerationChunk(
text=parsed_response.get("response", ""), generation_info=generation_info
)
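# Illustrative example (assumption: Ollama streams newline-delimited JSON from
# /api/generate): a line such as '{"response": "Hello", "done": false}' becomes
# GenerationChunk(text="Hello", generation_info=None), while the final line with
# '"done": true' carries the run statistics in generation_info.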
class OllamaEndpointNotFoundError(Exception):
"""Raised when the Ollama endpoint is not found."""
class _OllamaCommon(BaseLanguageModel):
base_url: str = "http://localhost:11434"
"""模型托管的基本URL。"""
model: str = "llama2"
"""要使用的模型名称。"""
mirostat: Optional[int] = None
"""启用Mirostat采样以控制困惑度。
(默认值:0,0 = 禁用,1 = Mirostat,2 = Mirostat 2.0)"""
mirostat_eta: Optional[float] = None
"""影响算法对生成文本的反馈作出响应的速度。较低的学习率会导致调整速度较慢,而较高的学习率会使算法更具响应性。(默认值:0.1)"""
mirostat_tau: Optional[float] = None
"""控制输出的一致性和多样性之间的平衡。较低的值会导致更加聚焦和连贯的文本。(默认值:5.0)"""
num_ctx: Optional[int] = None
"""设置用于生成下一个标记的上下文窗口的大小。(默认值:2048)"""
num_gpu: Optional[int] = None
"""要使用的GPU数量。在macOS上,默认值为1,以启用Metal支持,为0则禁用。"""
num_thread: Optional[int] = None
"""设置计算过程中要使用的线程数。
默认情况下,Ollama会检测以获得最佳性能。
建议将此值设置为系统具有的物理CPU核心数(而不是逻辑核心数)。"""
num_predict: Optional[int] = None
"""生成文本时预测的最大标记数。
(默认值:128,-1 = 无限生成,-2 = 填充上下文)"""
repeat_last_n: Optional[int] = None
"""设置模型向后查看的距离,以防止重复。(默认值:64,0 = 禁用,-1 = num_ctx)"""
repeat_penalty: Optional[float] = None
"""设置对重复的惩罚程度。较高的值(例如1.5)会更严厉地惩罚重复,而较低的值(例如0.9)则会更宽容。(默认值:1.1)"""
temperature: Optional[float] = None
"""模型的温度。增加温度会使模型的回答更具创造性。(默认值:0.8)"""
stop: Optional[List[str]] = None
"""设置要使用的停止标记。"""
tfs_z: Optional[float] = None
"""尾部自由抽样用于减少输出中不太可能的标记的影响。较高的值(例如2.0)将减少影响,而值为1.0将禁用此设置。(默认值:1)"""
top_k: Optional[int] = None
"""减少生成无意义内容的概率。较高的值(例如100)会产生更多不同的答案,而较低的值(例如10)会更保守。(默认值:40)"""
top_p: Optional[float] = None
"""与top-k一起使用。较高的值(例如,0.95)会导致生成更多样化的文本,而较低的值(例如,0.5)会生成更加集中和保守的文本。(默认值:0.9)"""
system: Optional[str] = None
"""系统提示(覆盖了模型文件中定义的内容)"""
template: Optional[str] = None
"""完整提示或提示模板(覆盖在模型文件中定义的内容)"""
format: Optional[str] = None
"""指定输出的格式(例如,json)"""
timeout: Optional[int] = None
"""请求流的超时时间"""
keep_alive: Optional[Union[int, str]] = None
"""模型将保持加载到内存中的时间长短。
参数(默认值:5分钟)可以设置为:
1. 一个Golang中的持续时间字符串(如"10m"或"24h");
2. 以秒为单位的数字(如3600);
3. 任何负数,将使模型保持加载在内存中(例如-1或"-1m");
4. 0将在生成响应后立即卸载模型;
参见[Ollama文档](https://github.com/ollama/ollama/blob/main/docs/faq.md#how-do-i-keep-a-model-loaded-in-memory-or-make-it-unload-immediately)
"""
headers: Optional[dict] = None
"""传递到端点的其他标头(例如授权,引用者)。
这在Ollama托管在需要身份验证令牌的云服务上时非常有用。"""
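# Illustrative example (assumption): for a hosted instance behind token auth,
# e.g. headers={"Authorization": "Bearer <token>"}.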
@property
def _default_params(self) -> Dict[str, Any]:
"""获取调用Ollama时的默认参数。"""
return {
"model": self.model,
"format": self.format,
"options": {
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
"num_ctx": self.num_ctx,
"num_gpu": self.num_gpu,
"num_thread": self.num_thread,
"num_predict": self.num_predict,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"stop": self.stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
"system": self.system,
"template": self.template,
"keep_alive": self.keep_alive,
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""获取识别参数。"""
return {**{"model": self.model, "format": self.format}, **self._default_params}
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {"prompt": prompt, "images": images}
yield from self._create_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate",
**kwargs,
)
async def _acreate_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
payload = {"prompt": prompt, "images": images}
async for item in self._acreate_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate",
**kwargs,
):
yield item
def _create_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
if "options" in kwargs:
params["options"] = kwargs["options"]
else:
params["options"] = {
**params["options"],
"stop": stop,
**{k: v for k, v in kwargs.items() if k not in self._default_params},
}
if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}
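# Illustrative shape of the resulting payload (assumption, for a plain prompt):
# {"prompt": "...", "images": [], "model": "llama2", "format": None,
#  "options": {..., "stop": [...]}, "system": None, "template": None,
#  "keep_alive": None}, i.e. Ollama's /api/generate request schema plus options.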
response = requests.post(
url=api_url,
headers={
"Content-Type": "application/json",
**(self.headers if isinstance(self.headers, dict) else {}),
},
json=request_payload,
stream=True,
timeout=self.timeout,
)
response.encoding = "utf-8"
if response.status_code != 200:
if response.status_code == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404. "
"Maybe your model is not found "
f"and you should pull the model with `ollama pull {self.model}`."
)
else:
optional_detail = response.text
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
return response.iter_lines(decode_unicode=True)
async def _acreate_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
if "options" in kwargs:
params["options"] = kwargs["options"]
else:
params["options"] = {
**params["options"],
"stop": stop,
**{k: v for k, v in kwargs.items() if k not in self._default_params},
}
if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}
async with aiohttp.ClientSession() as session:
async with session.post(
url=api_url,
headers={
"Content-Type": "application/json",
**(self.headers if isinstance(self.headers, dict) else {}),
},
json=request_payload,
timeout=self.timeout,
) as response:
if response.status != 200:
if response.status == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = await response.text()
raise ValueError(
f"Ollama call failed with status code {response.status}."
f" Details: {optional_detail}"
)
async for line in response.content:
yield line.decode("utf-8")
def _stream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
async def _astream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
class Ollama(BaseLLM, _OllamaCommon):
"""Ollama runs large language models locally.
To use, follow the instructions at https://ollama.ai/.
Example:
.. code-block:: python
from langchain_community.llms import Ollama
ollama = Ollama(model="llama2")"""
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
@property
def _llm_type(self) -> str:
"""llm的返回类型。"""
return "ollama-llm"
def _generate( # type: ignore[override]
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""调用Ollama的生成端点。
参数:
prompt: 传递给模型的提示。
stop: 生成时可选的停止词列表。
返回:
模型生成的字符串。
示例:
.. code-block:: python
response = ollama("Tell me a joke.")
"""
# TODO: add caching here.
generations = []
for prompt in prompts:
final_chunk = super()._stream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
generations.append([final_chunk])
return LLMResult(generations=generations) # type: ignore[arg-type]
async def _agenerate( # type: ignore[override]
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""调用Ollama的生成端点。
参数:
prompt: 传递给模型的提示。
stop: 生成时可选的停止词列表。
返回:
模型生成的字符串。
示例:
.. code-block:: python
response = ollama("Tell me a joke.")
"""
# TODO: add caching here.
generations = []
for prompt in prompts:
final_chunk = await super()._astream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager, # type: ignore[arg-type]
verbose=self.verbose,
**kwargs,
)
generations.append([final_chunk])
return LLMResult(generations=generations) # type: ignore[arg-type]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
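# Minimal usage sketch (assumptions: a local Ollama server is running on the
# default port and the "llama2" model has been pulled via `ollama pull llama2`):
#
#     from langchain_community.llms import Ollama
#
#     llm = Ollama(model="llama2", temperature=0.2)
#     print(llm.invoke("Tell me a joke."))           # aggregated completion
#     for chunk in llm.stream("Tell me a joke."):    # token-by-token streaming
#         print(chunk, end="", flush=True)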