9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
class OllamaEmbedding(BaseEmbedding):
    """Embedding model served by an Ollama server.

    Every text is embedded with one HTTP POST to ``{base_url}/api/embeddings``
    using the synchronous ``requests`` client; the async variants offload that
    blocking call to the event loop's default executor.
    """

    base_url: str = Field(description="Base url the model is hosted by Ollama")
    model_name: str = Field(description="The Ollama model to use.")
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        # pydantic spells "less than or equal" as ``le``; the former ``lte``
        # keyword is not a recognised constraint, so the bound was ignored.
        le=2048,
    )
    # Sent verbatim as the ``options`` object of every embeddings request.
    ollama_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Ollama API."
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = "http://localhost:11434",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Create an Ollama embedding model.

        Args:
            model_name: Name of the Ollama model to embed with.
            base_url: Root URL of the Ollama server.
            embed_batch_size: Batch size used by the base class for batching.
            ollama_additional_kwargs: Extra model options forwarded as the
                request's ``options`` field; ``None`` becomes an empty dict.
            callback_manager: Optional callback manager for the base class.
        """
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            embed_batch_size=embed_batch_size,
            ollama_additional_kwargs=ollama_additional_kwargs or {},
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialisation name of this class."""
        return "OllamaEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query string."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async version of ``_get_query_embedding``."""
        return await self._run_blocking(self.get_general_text_embedding, query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document text."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async version of ``_get_text_embedding``."""
        return await self._run_blocking(self.get_general_text_embedding, text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts, one HTTP request per text, preserving order."""
        return [self.get_general_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async version of ``_get_text_embeddings``."""
        return await self._run_blocking(self._get_text_embeddings, texts)

    async def _run_blocking(self, func: Any, *args: Any) -> Any:
        """Run the blocking callable ``func(*args)`` in the default executor.

        The HTTP client used here is synchronous; awaiting it through the
        executor keeps the event loop responsive during the round trip.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    @staticmethod
    def _error_detail(response: Any) -> Any:
        """Best-effort extraction of the ``"error"`` field of a failed response.

        Falls back to the raw body text when the body is not a JSON object, so
        a non-JSON error page cannot mask the HTTP status with a decode error.
        """
        try:
            body = response.json()
        except ValueError:
            return response.text
        if isinstance(body, dict):
            return body.get("error")
        return response.text

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Embed ``prompt`` with a single call to the Ollama embeddings API.

        Args:
            prompt: Text to embed.

        Returns:
            The ``"embedding"`` value from the JSON response body.

        Raises:
            ImportError: If ``requests`` is not installed.
            ValueError: If the server returns a non-200 status, or the body is
                not JSON or lacks an ``"embedding"`` key.
        """
        try:
            import requests
        except ImportError as e:
            raise ImportError(
                "Could not import requests library. "
                "Please install requests with `pip install requests`"
            ) from e

        ollama_request_body = {
            "prompt": prompt,
            "model": self.model_name,
            "options": self.ollama_additional_kwargs,
        }
        response = requests.post(
            url=f"{self.base_url}/api/embeddings",
            headers={"Content-Type": "application/json"},
            json=ollama_request_body,
        )
        response.encoding = "utf-8"

        if response.status_code != 200:
            raise ValueError(
                f"Ollama call failed with status code {response.status_code}."
                f" Details: {self._error_detail(response)}"
            )

        try:
            return response.json()["embedding"]
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers requests' JSONDecodeError on every requests
            # version; KeyError/TypeError cover a JSON body of the wrong shape.
            raise ValueError(
                f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
            ) from e