Source code for langchain_community.embeddings.ollama
import logging
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra
logger = logging.getLogger(__name__)
[docs]class OllamaEmbeddings(BaseModel, Embeddings):
"""Ollama在本地运行大型语言模型。
要使用,请按照 https://ollama.ai/ 上的说明操作。
示例:
.. code-block:: python
from langchain_community.embeddings import OllamaEmbeddings
ollama_emb = OllamaEmbeddings(
model="llama:7b",
)
r1 = ollama_emb.embed_documents(
[
"Alpha是希腊字母表的第一个字母",
"Beta是希腊字母表的第二个字母",
]
)
r2 = ollama_emb.embed_query(
"希腊字母表的第二个字母是什么"
)"""
base_url: str = "http://localhost:11434"
"""模型托管的基本URL。"""
model: str = "llama2"
"""要使用的模型名称。"""
embed_instruction: str = "passage: "
"""用于嵌入文档的指令。"""
query_instruction: str = "query: "
"""用于嵌入查询的指令。"""
mirostat: Optional[int] = None
"""启用Mirostat采样以控制困惑度。
(默认值:0,0 = 禁用,1 = Mirostat,2 = Mirostat 2.0)"""
mirostat_eta: Optional[float] = None
"""影响算法对生成文本的反馈作出响应的速度。较低的学习率会导致调整速度较慢,而较高的学习率会使算法更具响应性。(默认值:0.1)"""
mirostat_tau: Optional[float] = None
"""控制输出的一致性和多样性之间的平衡。较低的值会导致更加聚焦和连贯的文本。(默认值:5.0)"""
num_ctx: Optional[int] = None
"""设置用于生成下一个标记的上下文窗口的大小。(默认值:2048)"""
num_gpu: Optional[int] = None
"""要使用的GPU数量。在macOS上,默认值为1,以启用Metal支持,为0则禁用。"""
num_thread: Optional[int] = None
"""设置计算过程中要使用的线程数。
默认情况下,Ollama会检测以获得最佳性能。
建议将此值设置为系统具有的物理CPU核心数(而不是逻辑核心数)。"""
repeat_last_n: Optional[int] = None
"""设置模型向后查看的距离,以防止重复。(默认值:64,0 = 禁用,-1 = num_ctx)"""
repeat_penalty: Optional[float] = None
"""设置对重复的惩罚程度。较高的值(例如1.5)会更严厉地惩罚重复,而较低的值(例如0.9)则会更宽容。(默认值:1.1)"""
temperature: Optional[float] = None
"""模型的温度。增加温度会使模型的回答更具创造性。(默认值:0.8)"""
stop: Optional[List[str]] = None
"""设置要使用的停止标记。"""
tfs_z: Optional[float] = None
"""尾部自由抽样用于减少输出中不太可能的标记的影响。较高的值(例如2.0)将减少影响,而值为1.0将禁用此设置。(默认值:1)"""
top_k: Optional[int] = None
"""减少生成无意义内容的概率。较高的值(例如100)会产生更多不同的答案,而较低的值(例如10)会更保守。(默认值:40)"""
top_p: Optional[float] = None
"""与top-k一起使用。较高的值(例如,0.95)会导致生成更多样化的文本,而较低的值(例如,0.5)会生成更加集中和保守的文本。(默认值:0.9)"""
show_progress: bool = False
"""是否显示tqdm进度条。必须安装`tqdm`。"""
headers: Optional[dict] = None
"""传递到端点的其他标头(例如授权,引用者)。
这在Ollama托管在需要身份验证令牌的云服务上时非常有用。"""
@property
def _default_params(self) -> Dict[str, Any]:
"""获取调用Ollama时的默认参数。"""
return {
"model": self.model,
"options": {
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
"num_ctx": self.num_ctx,
"num_gpu": self.num_gpu,
"num_thread": self.num_thread,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"stop": self.stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
}
model_kwargs: Optional[dict] = None
"""其他模型关键字参数"""
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""获取识别参数。"""
return {**{"model": self.model}, **self._default_params}
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
def _process_emb_response(self, input: str) -> List[float]:
"""处理来自API的响应。
参数:
response:来自API的响应。
返回:
作为字典的响应。
"""
headers = {
"Content-Type": "application/json",
**(self.headers or {}),
}
try:
res = requests.post(
f"{self.base_url}/api/embeddings",
headers=headers,
json={"model": self.model, "prompt": input, **self._default_params},
)
except requests.exceptions.RequestException as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
if res.status_code != 200:
raise ValueError(
"Error raised by inference API HTTP code: %s, %s"
% (res.status_code, res.text)
)
try:
t = res.json()
return t["embedding"]
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {res.text}"
)
def _embed(self, input: List[str]) -> List[List[float]]:
if self.show_progress:
try:
from tqdm import tqdm
iter_ = tqdm(input, desc="OllamaEmbeddings")
except ImportError:
logger.warning(
"Unable to show progress bar because tqdm could not be imported. "
"Please install with `pip install tqdm`."
)
iter_ = input
else:
iter_ = input
return [self._process_emb_response(prompt) for prompt in iter_]
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""使用Ollama部署的嵌入模型嵌入文档。
参数:
texts:要嵌入的文本列表。
返回:
每个文本的嵌入列表。
"""
instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts]
embeddings = self._embed(instruction_pairs)
return embeddings
[docs] def embed_query(self, text: str) -> List[float]:
"""使用Ollama部署的嵌入模型嵌入一个查询。
参数:
text: 要嵌入的文本。
返回:
文本的嵌入。
"""
instruction_pair = f"{self.query_instruction}{text}"
embedding = self._embed([instruction_pair])[0]
return embedding